1
0
forked from IPKM/nmreval

Merge branch 'fit_constraints'

# Conflicts:
#	src/gui_qt/main/management.py
This commit is contained in:
Dominik Demuth 2023-09-19 12:39:32 +02:00
commit 04037d6b4d
18 changed files with 574 additions and 619 deletions

View File

@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
# Form implementation generated from reading ui file 'resources/_ui/fitmodelwidget.ui'
# Form implementation generated from reading ui file 'src/resources/_ui/fitmodelwidget.ui'
#
# Created by: PyQt5 UI code generator 5.12.3
# Created by: PyQt5 UI code generator 5.15.9
#
# WARNING! All changes made in this file will be lost!
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again. Do not edit this file unless you know what you are doing.
from PyQt5 import QtCore, QtGui, QtWidgets
@ -13,7 +14,7 @@ from PyQt5 import QtCore, QtGui, QtWidgets
class Ui_FitParameter(object):
def setupUi(self, FitParameter):
FitParameter.setObjectName("FitParameter")
FitParameter.resize(365, 78)
FitParameter.resize(365, 66)
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.MinimumExpanding)
sizePolicy.setHorizontalStretch(0)
sizePolicy.setVerticalStretch(0)
@ -36,7 +37,7 @@ class Ui_FitParameter(object):
self.parametername.setObjectName("parametername")
self.horizontalLayout_2.addWidget(self.parametername)
self.parameter_line = LineEdit(FitParameter)
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed)
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed)
sizePolicy.setHorizontalStretch(0)
sizePolicy.setVerticalStretch(0)
sizePolicy.setHeightForWidth(self.parameter_line.sizePolicy().hasHeightForWidth())
@ -44,20 +45,12 @@ class Ui_FitParameter(object):
self.parameter_line.setText("")
self.parameter_line.setObjectName("parameter_line")
self.horizontalLayout_2.addWidget(self.parameter_line)
spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
self.horizontalLayout_2.addItem(spacerItem)
self.fixed_check = QtWidgets.QCheckBox(FitParameter)
self.fixed_check.setObjectName("fixed_check")
self.horizontalLayout_2.addWidget(self.fixed_check)
self.global_checkbox = QtWidgets.QCheckBox(FitParameter)
self.global_checkbox.setObjectName("global_checkbox")
self.horizontalLayout_2.addWidget(self.global_checkbox)
self.toolButton = QtWidgets.QToolButton(FitParameter)
self.toolButton.setText("")
self.toolButton.setPopupMode(QtWidgets.QToolButton.InstantPopup)
self.toolButton.setArrowType(QtCore.Qt.RightArrow)
self.toolButton.setObjectName("toolButton")
self.horizontalLayout_2.addWidget(self.toolButton)
self.verticalLayout.addLayout(self.horizontalLayout_2)
self.frame = QtWidgets.QFrame(FitParameter)
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Maximum)

View File

@ -8,6 +8,7 @@ from pyqtgraph import mkPen
from nmreval.data.points import Points
from nmreval.data.signals import Signal
from nmreval.lib.logger import logger
from nmreval.utils.text import convert
from nmreval.data.bds import BDS
from nmreval.data.dsc import DSC
@ -356,7 +357,7 @@ class ExperimentContainer(QtCore.QObject):
elif mode in ['imag', 'all'] and self.plot_imag is not None:
self.plot_imag.set_symbol(symbol=symbol, size=size, color=color)
else:
print('Updating symbol failed for ' + str(self.id))
logger.warning(f'Updating symbol failed for {self.id}')
def setLine(self, *, width=None, style=None, color=None, mode='real'):
if mode in ['real', 'all']:
@ -368,7 +369,7 @@ class ExperimentContainer(QtCore.QObject):
elif mode in ['imag', 'all'] and self.plot_imag is not None:
self.plot_imag.set_line(width=width, style=style, color=color)
else:
print('Updating line failed for ' + str(self.id))
logger.warning(f'Updating line failed for {self.id}')
def update_property(self, key1: str, key2: str, value: Any):
keykey = key2.split()

View File

@ -1,3 +1,4 @@
from nmreval.lib.logger import logger
from nmreval.math import apodization
from nmreval.lib.importer import find_models
from nmreval.utils.text import convert
@ -67,7 +68,7 @@ class EditSignalWidget(QtWidgets.QWidget, Ui_Form):
self.do_something.emit(sender, (ph0, ph1, pvt))
else:
print('You should never reach this by accident.')
logger.warning(f'You should never reach this by accident, invalid sender {sender!r}')
@QtCore.pyqtSlot(int, name='on_apodcombobox_currentIndexChanged')
def change_apodization(self, index):

View File

@ -19,19 +19,20 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
super().__init__(parent)
self.setupUi(self)
self.parametername.setText(label + ' ')
self.name = label
self.parametername.setText(convert(label) + ' ')
validator = QtGui.QDoubleValidator()
self.parameter_line.setValidator(validator)
self.parameter_line.setText('1')
self.parameter_line.setMaximumWidth(240)
self.lineEdit.setMaximumWidth(60)
self.lineEdit_2.setMaximumWidth(60)
self.parameter_line.setMaximumWidth(160)
self.lineEdit.setMaximumWidth(100)
self.lineEdit_2.setMaximumWidth(100)
self.label_3.setText(f'< {label} <')
self.label_3.setText(f'< {convert(label)} <')
self.checkBox.stateChanged.connect(self.enableBounds)
self.global_checkbox.stateChanged.connect(lambda: self.state_changed.emit())
self.parameter_line.editingFinished.connect(self.update_parameter)
self.parameter_line.values_requested.connect(lambda: self.value_requested.emit(self))
self.parameter_line.replace_single_values.connect(lambda: self.replace_single_value.emit(None))
self.parameter_line.editingFinished.connect(lambda: self.value_changed.emit(self.parameter_line.text()))
@ -40,18 +41,12 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
if fixed:
self.fixed_check.hide()
self.menu = QtWidgets.QMenu(self)
self.add_links()
self.is_linked = None
self.parameter_pos = None
self.func_idx = None
self._linetext = '1'
@property
def name(self):
return convert(self.parametername.text().strip(), old='html', new='str')
self.menu = QtWidgets.QMenu(self)
def set_parameter_string(self, p: str):
self.parameter_line.setText(p)
@ -71,11 +66,6 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
def set_parameter(self, p: float | None, bds: tuple[float, float, bool] = None,
fixed: bool = None, glob: bool = None):
if p is None:
# bad hack: linked parameter return (None, linked parameter)
# if p is None -> parameter is linked to argument given by bds
self.link_parameter(linkto=bds)
else:
ptext = f'{p:.4g}'
self.set_parameter_string(ptext)
@ -90,19 +80,10 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
self.global_checkbox.setCheckState(QtCore.Qt.Checked if glob else QtCore.Qt.Unchecked)
def get_parameter(self):
if self.is_linked:
try:
p = float(self._linetext)
except ValueError:
p = 1.0
else:
try:
p = float(self.parameter_line.text().replace(',', '.'))
except ValueError:
_ = QtWidgets.QMessageBox().warning(self, 'Invalid value',
f'{self.parametername.text()} contains invalid values',
QtWidgets.QMessageBox.Cancel)
return None
p = self.parameter_line.text().replace(',', '.')
if self.checkBox.isChecked():
try:
@ -119,75 +100,27 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
bounds = (lb, rb)
return p, bounds, not self.fixed_check.isChecked(), self.global_checkbox.isChecked(), self.is_linked
return p, bounds, not self.fixed_check.isChecked(), self.global_checkbox.isChecked()
@QtCore.pyqtSlot(bool)
def set_fixed(self, state: bool):
# self.global_checkbox.setVisible(not state)
self.frame.setVisible(not state)
def add_links(self, parameter: dict = None):
if parameter is None:
parameter = {}
self.menu.clear()
ac = QtWidgets.QAction('Link to...', self)
ac.triggered.connect(self.link_parameter)
self.menu.addAction(ac)
for model_key, model_funcs in parameter.items():
m = QtWidgets.QMenu('Model ' + model_key, self)
for func_name, func_params in model_funcs.items():
m2 = QtWidgets.QMenu(func_name, m)
for p_name, idx in func_params:
ac = QtWidgets.QAction(p_name, m2)
ac.setData((model_key, *idx))
ac.triggered.connect(self.link_parameter)
m2.addAction(ac)
m.addMenu(m2)
self.menu.addMenu(m)
self.toolButton.setMenu(self.menu)
@QtCore.pyqtSlot()
def link_parameter(self, linkto=None):
if linkto is None:
action = self.sender()
else:
action = False
for m in self.menu.actions():
if m.menu():
for a in m.menu().actions():
if a.data() == linkto:
action = a
break
if action:
break
if (self.func_idx, self.parameter_pos) == action.data():
return
def update_parameter(self):
new_value = self.parameter_line.text()
if not new_value:
self.parameter_line.setText('1')
try:
new_text = f'Linked to {action.parentWidget().title()}.{action.text()}'
self._linetext = self.parameter_line.text()
self.parameter_line.setText(new_text)
self.parameter_line.setEnabled(False)
self.global_checkbox.hide()
self.global_checkbox.blockSignals(True)
self.global_checkbox.setCheckState(QtCore.Qt.Checked)
self.global_checkbox.blockSignals(False)
self.frame.hide()
self.is_linked = action.data()
float(new_value)
is_text = False
except ValueError:
is_text = True
self.global_checkbox.setCheckState(False)
except AttributeError:
self.parameter_line.setText(self._linetext)
self.parameter_line.setEnabled(True)
if self.fixed_check.isEnabled():
self.global_checkbox.show()
self.frame.show()
self.is_linked = None
self.state_changed.emit()
self.set_fixed(is_text)
class QSaveModelDialog(QtWidgets.QDialog, Ui_SaveDialog):
@ -282,8 +215,17 @@ class FitModelTree(QtWidgets.QTreeWidget):
idx = item.data(0, self.counterRole)
self.itemRemoved.emit(idx)
def add_function(self, idx: int, cnt: int, op: int, name: str, color: QtGui.QColor | str | tuple,
parent: QtWidgets.QTreeWidgetItem = None, children: list = None, active: bool = True, **kwargs):
def add_function(self,
idx: int,
cnt: int,
op: int,
name: str,
color: QtGui.QColor | str | tuple,
parent: QtWidgets.QTreeWidgetItem = None,
children: list = None,
active: bool = True,
param_names: list[str] = None,
**kwargs):
"""
Add function to tree and dictionary of functions.
"""
@ -298,6 +240,10 @@ class FitModelTree(QtWidgets.QTreeWidget):
it.setData(0, self.counterRole, cnt)
it.setData(0, self.operatorRole, op)
it.setText(0, name)
if param_names is not None:
it.setToolTip(0,
'Parameter names:\n' +
'\n'.join(f'{pn}({cnt})' for pn in param_names))
it.setForeground(0, QtGui.QBrush(color))
it.setIcon(0, get_icon(self.icons[op]))

View File

@ -1,5 +1,8 @@
from __future__ import annotations
from typing import Optional
from nmreval.fit.parameter import Parameter
from nmreval.utils.text import convert
from ..Qt import QtWidgets, QtCore, QtGui
@ -62,8 +65,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
self.glob_values = [1] * len(func.params)
for k, v in enumerate(func.params):
name = convert(v)
widgt = FitModelWidget(label=name, parent=self.scrollwidget)
widgt = FitModelWidget(label=v, parent=self.scrollwidget)
widgt.parameter_pos = k
widgt.func_idx = idx
try:
@ -83,7 +85,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
self.global_parameter.append(widgt)
self.scrollwidget.layout().addWidget(widgt)
widgt2 = ParameterSingleWidget(name=name, parent=self.scrollwidget2)
widgt2 = ParameterSingleWidget(name=v, parent=self.scrollwidget2)
widgt2.valueChanged.connect(self.change_single_parameter)
widgt2.removeSingleValue.connect(self.change_single_parameter)
widgt2.installEventFilter(self)
@ -115,20 +117,22 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
self.scrollwidget.layout().addStretch(1)
self.scrollwidget2.layout().addStretch(1)
def set_links(self, parameter):
for w in self.global_parameter:
if isinstance(w, FitModelWidget):
w.add_links(parameter)
# def set_links(self, parameter):
# for w in self.global_parameter:
# if isinstance(w, FitModelWidget):
# w.add_links(parameter)
@QtCore.pyqtSlot(str)
def change_global_parameter(self, value: str, idx: int = None):
if idx is None:
idx = self.global_parameter.index(self.sender())
self.glob_values[idx] = float(value)
# self.glob_values[idx] = float(value)
self.glob_values[idx] = value
if self.data_values[self.comboBox.currentData()][idx] is None:
self.data_parameter[idx].blockSignals(True)
self.data_parameter[idx].value = float(value)
# self.data_parameter[idx].value = float(value)
self.data_parameter[idx].value = value
self.data_parameter[idx].blockSignals(False)
@QtCore.pyqtSlot(str, object)
@ -171,7 +175,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
# disable single parameter if it is set global, enable if global is unset
widget = self.sender()
idx = self.global_parameter.index(widget)
enable = (widget.global_checkbox.checkState() == QtCore.Qt.Unchecked) and (widget.is_linked is None)
enable = (widget.global_checkbox.checkState() == QtCore.Qt.Unchecked)
self.data_parameter[idx].setEnabled(enable)
def select_next_preview(self, direction):
@ -204,64 +208,50 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
if sid not in self.data_values:
self.data_values[sid] = [None] * len(self.data_parameter)
def get_parameter(self, use_func=None):
def get_parameter(self, use_func=None) -> tuple[dict, list]:
bds = []
is_global = []
is_fixed = []
globs = []
is_linked = []
param_general = []
for g in self.global_parameter:
if isinstance(g, FitModelWidget):
p_i, bds_i, fixed_i, global_i, link_i = g.get_parameter()
p_i, bds_i, fixed_i, global_i = g.get_parameter()
parameter_i = Parameter(name=g.name, value=p_i, lb=bds_i[0], ub=bds_i[1], var=fixed_i)
param_general.append(parameter_i)
globs.append(p_i)
bds.append(bds_i)
is_fixed.append(fixed_i)
is_global.append(global_i)
is_linked.append(link_i)
lb, ub = list(zip(*bds))
data_parameter = {}
if use_func is None:
use_func = list(self.data_values.keys())
global_p = None
for sid, parameter in self.data_values.items():
if sid not in use_func:
continue
kw_p = {}
p = []
if global_p is None:
global_p = {'p': [], 'idx': [], 'var': [], 'ub': [], 'lb': []}
for i, (p_i, g) in enumerate(zip(parameter, self.global_parameter)):
if isinstance(g, FitModelWidget):
if (p_i is None) or is_global[i]:
p.append(globs[i])
if is_global[i]:
if i not in global_p['idx']:
global_p['p'].append(globs[i])
global_p['idx'].append(i)
global_p['var'].append(is_fixed[i])
global_p['ub'].append(ub[i])
global_p['lb'].append(lb[i])
# set has no oen value
p.append(param_general[i].copy())
else:
p.append(p_i)
lb, ub = bds[i]
try:
if p[i] > ub[i]:
raise ValueError(f'Parameter {g.name} is outside bounds ({lb[i]}, {ub[i]})')
if not (lb < p_i < ub):
raise ValueError(f'Parameter {g.name} is outside bounds ({lb}, {ub})')
except TypeError:
pass
try:
if p[i] < lb[i]:
raise ValueError(f'Parameter {g.name} is outside bounds ({lb[i]}, {ub[i]})')
except TypeError:
pass
# create Parameter
p.append(
Parameter(name=g.name, value=p_i, lb=lb, ub=ub, var=is_fixed[i])
)
else:
if p_i is None:
@ -273,7 +263,15 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
data_parameter[sid] = (p, kw_p)
return data_parameter, lb, ub, is_fixed, global_p, is_linked
global_parameter = []
for param, global_flag in zip(param_general, is_global):
if global_flag:
param.is_global = True
global_parameter.append(param)
else:
global_parameter.append(None)
return data_parameter, global_parameter
def set_parameter(self, set_id: str | None, parameter: list[float]) -> int:
num_parameter = list(filter(lambda g: not isinstance(g, SelectionWidget), self.global_parameter))
@ -304,12 +302,12 @@ class ParameterSingleWidget(QtWidgets.QWidget):
self._init_ui()
self._name = name
self.name = name
self.label.setText(convert(name))
self.label.setToolTip('If this is bold then this parameter is only for this data. '
'Otherwise, the general parameter is used and displayed')
self.value_line.setValidator(QtGui.QDoubleValidator())
# self.value_line.setValidator(QtGui.QDoubleValidator())
self.value_line.textChanged.connect(lambda: self.valueChanged.emit(self.value) if self.value is not None else 0)
self.reset_button.clicked.connect(lambda x: self.removeSingleValue.emit())
@ -343,9 +341,10 @@ class ParameterSingleWidget(QtWidgets.QWidget):
@value.setter
def value(self, val):
self.value_line.setText(f'{float(val):.5g}')
# self.value_line.setText(f'{float(val):.5g}')
self.value_line.setText(f'{val}')
def show_as_local_parameter(self, is_local):
def show_as_local_parameter(self, is_local: bool):
if is_local:
self.label.setStyleSheet('font-weight: bold;')
else:

View File

@ -128,7 +128,7 @@ class QFunctionWidget(QtWidgets.QWidget, Ui_Form):
self.newFunction.emit(idx, cnt)
self.add_function(idx, cnt, op, name, col)
self.add_function(idx, cnt, op, name, col, param_names=self.functions[idx].params)
def add_function(self, idx: int, cnt: int, op: int,
name: str, color: str | tuple[float, float, float] | BaseColor, **kwargs):
@ -141,6 +141,7 @@ class QFunctionWidget(QtWidgets.QWidget, Ui_Form):
qcolor = QtGui.QColor.fromRgbF(*color)
else:
qcolor = QtGui.QColor(color)
self.functree.add_function(idx, cnt, op, name, qcolor, **kwargs)
f = self.functions[idx]

View File

@ -9,6 +9,9 @@ import numpy as np
from pyqtgraph import mkPen
from nmreval.fit._meta import MultiModel, ModelFactory
from nmreval.fit.data import Data
from nmreval.fit.model import Model
from nmreval.fit.parameter import Parameters
from nmreval.fit.result import FitResult
from .fit_forms import FitTableWidget
@ -116,7 +119,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
# collect parameter names etc. to allow linkage
self._func_list[self._current_model] = self.functionwidget.get_parameter_list()
dialog.set_links(self._func_list)
# dialog.set_links(self._func_list)
# show same tab (general parameter/Data parameter)
tab_idx = 0
@ -219,57 +222,49 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
def _prepare(self, model: list, function_use: list = None,
parameter: dict = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]:
if parameter is None:
parameter = {'parameter': {}, 'lb': (), 'ub': (), 'var': [],
'glob': {'idx': [], 'p': [], 'var': [], 'lb': [], 'ub': []},
'links': [], 'color': []}
parameter = {
'data_parameter': {},
'global_parameter': [],
'links': [],
'color': [],
}
for i, f in enumerate(model):
if not f['active']:
continue
try:
p, lb, ub, var, glob, links = self.param_widgets[f['cnt']].get_parameter(function_use)
p, glob = self.param_widgets[f['cnt']].get_parameter(function_use)
except ValueError as e:
_ = QtWidgets.QMessageBox().warning(self, 'Invalid value', str(e),
QtWidgets.QMessageBox.Ok)
return {}, -1
p_len = len(parameter['lb'])
parameter['lb'] += lb
parameter['ub'] += ub
parameter['var'] += var
parameter['links'] += links
parameter['color'] += [f['color']]
parameter['color'].append(f['color'])
parameter['global_parameter'].extend(glob)
cnt = f['cnt']
for p_k, v_k in p.items():
if add_idx:
kw_k = {f'{k}_{cnt}': v for k, v in v_k[1].items()}
else:
kw_k = v_k[1]
if p_k in parameter['parameter']:
params, kw = parameter['parameter'][p_k]
if p_k in parameter['data_parameter']:
params, kw = parameter['data_parameter'][p_k]
params += v_k[0]
kw.update(kw_k)
else:
parameter['parameter'][p_k] = (v_k[0], kw_k)
for g_k, g_v in glob.items():
if g_k != 'idx':
parameter['glob'][g_k] += g_v
else:
parameter['glob']['idx'] += [idx_i + p_len for idx_i in g_v]
parameter['data_parameter'][p_k] = (v_k[0], kw_k)
if add_idx:
cnt += 1
if f['children']:
# recurse for children
child_parameter, cnt = self._prepare(f['children'], parameter=parameter, add_idx=add_idx, cnt=cnt)
_, cnt = self._prepare(f['children'], parameter=parameter, add_idx=add_idx, cnt=cnt)
return parameter, cnt
@ -280,30 +275,43 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
data = self.data_table.collect_data(default=self.default_combobox.currentData())
func_dict = {}
for k, mod in self.models.items():
func, order, param_len = ModelFactory.create_from_list(mod)
for model_name, model_parameter in self.models.items():
func, order, param_len = ModelFactory.create_from_list(model_parameter)
if func is None:
continue
if k in data:
parameter, _ = self._prepare(mod, function_use=data[k], add_idx=isinstance(func, MultiModel))
func = Model(func)
if model_name in data:
parameter, _ = self._prepare(model_parameter, function_use=data[model_name], add_idx=isinstance(func, MultiModel))
if parameter is None:
return
for (data_parameter, _) in parameter['data_parameter'].values():
for pname, param in zip(func.params, data_parameter):
param.name = pname
if self._complex[model_name] is not None:
for p_k, p_v in parameter['data_parameter'].items():
p_v[1].update({'complex_mode': self._complex[model_name]})
parameter['data_parameter'][p_k] = p_v[0], p_v[1]
for pname, param_value in zip(func.params, parameter['global_parameter']):
if param_value is not None:
param_value.name = pname
func.set_global_parameter(param_value)
parameter['func'] = func
parameter['order'] = order
parameter['len'] = param_len
parameter['complex'] = self._complex[k]
if self._complex[k] is not None:
for p_k, p_v in parameter['parameter'].items():
p_v[1].update({'complex_mode': self._complex[k]})
parameter['parameter'][p_k] = p_v[0], p_v[1]
parameter['complex'] = self._complex[model_name]
func_dict[k] = parameter
func_dict[model_name] = parameter
replaceable = []
for k, v in func_dict.items():
for model_name, v in func_dict.items():
for i, link_i in enumerate(v['links']):
if link_i is None:
continue
@ -334,7 +342,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
QtWidgets.QMessageBox.Ok)
return
replaceable.append((k, i, rep_model, repl_idx))
replaceable.append((model_name, i, rep_model, repl_idx))
replace_value = None
for p_k in f['parameter'].values():
@ -412,31 +420,37 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
def make_previews(self, x, models_parameters: dict):
self.preview_lines = []
# needed to create namespace
param_dict = Parameters()
cnt = 0
for model in models_parameters.values():
f = model['func']
for parameter_list in model['data_parameter'].values():
for i, p_value in enumerate(parameter_list[0]):
p_value.name = f.params[i]
param_dict.add_parameter(f'a{cnt}', p_value)
cnt += 1
for k, model in models_parameters.items():
f = model['func']
is_complex = self._complex[k]
parameters = model['parameter']
parameters = model['data_parameter']
color = model['color']
seen_parameter = []
for p, kwargs in parameters.values():
if (p, kwargs) in seen_parameter:
# plot only previews with different parameter
continue
seen_parameter.append((p, kwargs))
p_value = [pp.value for pp in p]
if is_complex is not None:
y = f.func(x, *p, complex_mode=is_complex, **kwargs)
y = f.func(x, *p_value, complex_mode=is_complex, **kwargs)
if np.iscomplexobj(y):
self.preview_lines.append(PlotItem(x=x, y=y.real, pen=mkPen(width=3)))
self.preview_lines.append(PlotItem(x=x, y=y.imag, pen=mkPen(width=3)))
else:
self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3)))
else:
y = f.func(x, *p, **kwargs)
y = f.func(x, *p_value, **kwargs)
self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3)))
if isinstance(f, MultiModel):
@ -444,7 +458,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
if is_complex is not None:
sub_kwargs.update({'complex_mode': is_complex})
for i, s in enumerate(f.subs(x, *p, **sub_kwargs)):
for i, s in enumerate(f.subs(x, *p_value, **sub_kwargs)):
pen_i = mkPen(QtGui.QColor.fromRgbF(*color[i]))
if np.iscomplexobj(s):
self.preview_lines.append(PlotItem(x=x, y=s.real, pen=pen_i))
@ -452,15 +466,17 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
else:
self.preview_lines.append(PlotItem(x=x, y=s, pen=pen_i))
param_dict.clear()
return self.preview_lines
def set_parameter(self, parameter: dict[str, FitResult]):
# which data uses which model
data = self.data_table.collect_data(default=self.default_combobox.currentData())
for fitted_model, fitted_data in data.items():
glob_fit_parameter = []
for fitted_model, fitted_data in data.items():
for fit_id, fit_curve in parameter.items():
if fit_id in fitted_data:
fit_parameter = list(fit_curve.parameter.values())

View File

@ -138,9 +138,7 @@ class DrawingsWidget(QtWidgets.QWidget, Ui_Form):
graph_id = self.graph_comboBox.currentData()
current_lines = self.lines[graph_id]
print(remove_rows)
for i in reversed(remove_rows):
print(i)
self.tableWidget.removeRow(i)
self.line_deleted.emit(current_lines[i], graph_id)

View File

@ -27,7 +27,6 @@ class MdiAreaTile(QtWidgets.QMdiArea):
pos = QtCore.QPoint(0, 0)
for win in window_list:
print(win.minimumSize())
win.setGeometry(rect)
win.move(pos)

View File

@ -1,110 +0,0 @@
import os.path
import json
import urllib.request
import webbrowser
import random
from ..Qt import QtGui, QtCore, QtWidgets
from .._py.pokemon import Ui_Dialog
random.seed()
class QPokemon(QtWidgets.QDialog, Ui_Dialog):
def __init__(self, number=None, parent=None):
super().__init__(parent=parent)
self.setupUi(self)
self._js = json.load(open(os.path.join(path_to_module, 'utils', 'pokemon.json'), 'r'), encoding='UTF-8')
self._id = 0
if number is not None and number in range(1, len(self._js)+1):
poke_nr = f'{number:03d}'
self._id = number
else:
poke_nr = f'{random.randint(1, len(self._js)):03d}'
self._id = int(poke_nr)
self._pokemon = None
self.show_pokemon(poke_nr)
self.label_15.linkActivated.connect(lambda x: webbrowser.open(x))
self.buttonBox.clicked.connect(self.randomize)
self.next_button.clicked.connect(self.show_next)
self.prev_button.clicked.connect(self.show_prev)
def show_pokemon(self, nr):
self._pokemon = self._js[nr]
self.setWindowTitle('Pokémon: ' + self._pokemon['Deutsch'])
for i in range(self.tabWidget.count(), -1, -1):
print('i', self.tabWidget.count(), i)
try:
self.tabWidget.widget(i).deleteLater()
except AttributeError:
pass
for n, img in self._pokemon['Bilder']:
w = QtWidgets.QWidget()
vl = QtWidgets.QVBoxLayout()
l = QtWidgets.QLabel(self)
l.setAlignment(QtCore.Qt.AlignHCenter)
pixmap = QtGui.QPixmap()
try:
pixmap.loadFromData(urllib.request.urlopen(img, timeout=0.5).read())
except IOError:
l.setText(n)
else:
sc_pixmap = pixmap.scaled(256, 256, QtCore.Qt.KeepAspectRatio)
l.setPixmap(sc_pixmap)
vl.addWidget(l)
w.setLayout(vl)
self.tabWidget.addTab(w, n)
if len(self._pokemon['Bilder']) <= 1:
self.tabWidget.tabBar().setVisible(False)
else:
self.tabWidget.tabBar().setVisible(True)
self.tabWidget.adjustSize()
self.name.clear()
keys = ['National-Dex', 'Kategorie', 'Typ', 'Größe', 'Gewicht', 'Farbe', 'Link']
label_list = [self.pokedex_nr, self.category, self.poketype, self.weight, self.height, self.color, self.info]
for (k, label) in zip(keys, label_list):
v = self._pokemon[k]
if isinstance(v, list):
v = os.path.join('', *v)
if k == 'Link':
v = '<a href={}>{}</a>'.format(v, v)
label.setText(v)
for k in ['Deutsch', 'Japanisch', 'Englisch', 'Französisch']:
v = self._pokemon[k]
self.name.addItem(k + ': ' + v)
self.adjustSize()
def randomize(self, idd):
if idd.text() == 'Retry':
new_number = f'{random.randint(1, len(self._js)):03d}'
self._id = int(new_number)
self.show_pokemon(new_number)
else:
self.close()
def show_next(self):
new_number = self._id + 1
if new_number > len(self._js):
new_number = 1
self._id = new_number
self.show_pokemon(f'{new_number:03d}')
def show_prev(self):
new_number = self._id - 1
if new_number == 0:
new_number = len(self._js)
self._id = new_number
self.show_pokemon(f'{new_number:03d}')

View File

@ -441,7 +441,7 @@ class UpperManagement(QtCore.QObject):
# all-encompassing error catch
try:
for model_id, model_p in parameter.items():
m = Model(model_p['func'])
m = model_p['func']
models[model_id] = m
m_complex = model_p['complex']
@ -450,13 +450,16 @@ class UpperManagement(QtCore.QObject):
# iterate over order of set id in active order and access parameter inside loop
# instead of directly looping
try:
list_ids = list(model_p['parameter'].keys())
list_ids = list(model_p['data_parameter'].keys())
set_order = [self.active_id.index(i) for i in list_ids]
except ValueError as e:
raise Exception('Getting order failed') from e
for pos in set_order:
set_id = list_ids[pos]
data_i = self.data[set_id]
set_params = model_p['data_parameter'][set_id]
try:
data_i = self.data[set_id]
except KeyError as e:
@ -499,18 +502,12 @@ class UpperManagement(QtCore.QObject):
d.set_model(m)
try:
d.set_parameter(set_params[0], var=model_p['var'],
lb=model_p['lb'], ub=model_p['ub'],
fun_kwargs=set_params[1])
d.set_parameter(set_params[0], fun_kwargs=set_params[1])
except Exception as e:
raise Exception('Setting parameter failed') from e
self.fitter.add_data(d)
model_globs = model_p['glob']
if model_globs:
m.set_global_parameter(**model_p['glob'])
for links_i in links:
self.fitter.set_link_parameter((models[links_i[0]], links_i[1]),
(models[links_i[2]], links_i[3]))
@ -1170,7 +1167,6 @@ class UpperManagement(QtCore.QObject):
@QtCore.pyqtSlot(dict)
def calc_relaxation(self, opts: dict):
params = opts['pts']
if len(params) == 4:
if params[3]:

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import numpy as np
from .model import Model
from .parameter import Parameters
from .parameter import Parameters, Parameter
class Data(object):
@ -16,7 +18,7 @@ class Data(object):
self.model = None
self.minimizer = None
self.parameter = Parameters()
self.para_keys = None
self.para_keys: list = []
self.fun_kwargs = {}
def __len__(self):
@ -68,12 +70,19 @@ class Data(object):
def get_model(self):
return self.model
def set_parameter(self, parameter, var=None, ub=None, lb=None,
default_bounds=False, fun_kwargs=None):
def set_parameter(self,
values: list[float | Parameter],
*,
var: list[bool] = None,
ub: list[float] = None,
lb: list[float] = None,
default_bounds: bool = False,
fun_kwargs: dict = None
):
"""
Creates parameter for this data.
If no Model is available, it falls back to the model
:param parameter: list of parameters
:param values: list of parameters
:param var: list of boolean or boolean; False fixes parameter at given list index.
Single value is broadcast to all parameter
:param ub: list of upper boundaries or float; Single value is broadcast to all parameter.
@ -87,23 +96,46 @@ class Data(object):
model = self.model
if model is None:
# Data has no unique
if self.minimizer is None:
model = None
else:
if self.minimizer is not None:
model = self.minimizer.fit_model
self.fun_kwargs.update(model.fun_kwargs)
if model is None:
raise ValueError('No model found, please set model before parameters')
if default_bounds:
if len(values) != len(model.params):
raise ValueError('Number of given parameter does not match number of model parameters')
is_parameter = [isinstance(v, Parameter) for v in values]
if all(is_parameter):
for p_i in values:
key = f"p{next(Parameters.parameter_counter)}"
self.parameter.add_parameter(key, p_i)
elif any(is_parameter):
raise ValueError('list of parameter are not all float of Parameter')
else:
if var is None:
var = [True] * len(values)
if lb is None:
if default_bounds:
lb = model.lb
else:
lb = [None] * len(values)
if ub is None:
if default_bounds:
ub = model.ub
else:
ub = [None] * len(values)
self.para_keys = self.parameter.add_parameter(parameter, var=var, lb=lb, ub=ub)
arg_names = ['name', 'value', 'var', 'lb', 'ub']
for parameter_arg in zip(model.params, values, var, lb, ub):
self.parameter.add(**{arg_name: arg_value for arg_name, arg_value in zip(arg_names, parameter_arg)})
self.para_keys = list(self.parameter.keys())
self.fun_kwargs.update(model.fun_kwargs)
if fun_kwargs is not None:
self.fun_kwargs.update(fun_kwargs)
@ -123,6 +155,18 @@ class Data(object):
else:
return [p.value for p in self.minimizer.parameters[self.parameter]]
def replace_parameter(self, key: str, parameter: Parameter) -> None:
tobereplaced = None
for k, v in self.parameter.items():
if v.name == parameter.name:
tobereplaced = k
break
if tobereplaced is None:
raise KeyError(f'Global parameter {key} not found in list of parameters')
self.para_keys[self.para_keys.index(tobereplaced)] = key
self.parameter.replace_parameter(tobereplaced, key, parameter)
def cost(self, p):
"""
Cost function :math:`y-f(p, x)`

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import warnings
from itertools import product
@ -21,13 +23,70 @@ class FitAbortException(Exception):
pass
# COST FUNCTIONS: f(x) - y (least_square, minimize), and f(x) (ODR)
def _cost_scipy_glob(p: list[float], data: list[Data], varpars: list[str], used_pars: list[list[str]]):
# replace values
for keys, values in zip(varpars, p):
for data_i in data:
if keys in data_i.parameter.keys():
# TODO move this to scaled_value setter
data_i.parameter[keys].scaled_value = values
data_i.parameter[keys].namespace[keys] = data_i.parameter[keys].value
r = []
# unpack parameter and calculate y values and concatenate all
for values, p_idx in zip(data, used_pars):
actual_parameters = [values.parameter[keys].value for keys in p_idx]
r = np.r_[r, values.cost(actual_parameters)]
return r
def _cost_scipy(p, data, varpars, used_pars):
for keys, values in zip(varpars, p):
data.parameter[keys].scaled_value = values
data.parameter[keys].namespace[keys] = data.parameter[keys].value
actual_parameters = [data.parameter[keys].value for keys in used_pars]
return data.cost(actual_parameters)
def _cost_odr(p: list[float], data: Data, varpars: list[str], used_pars: list[str], fitmode: int=0):
for keys, values in zip(varpars, p):
data.parameter[keys].scaled_value = values
data.parameter[keys].namespace[keys] = data.parameter[keys].value
actual_parameters = [data.parameter[keys].value for keys in used_pars]
return data.func(actual_parameters, data.x)
def _cost_odr_glob(p: list[float], data: list[Data], var_pars: list[str], used_pars: list[str]):
# replace values
for data_i in data:
_update_parameter(data_i, var_pars, p)
r = []
# unpack parameter and calculate y values and concatenate all
for values, p_idx in zip(data, used_pars):
actual_parameters = [values.parameter[keys].value for keys in p_idx]
r = np.r_[r, values.func(actual_parameters, values.x)]
return r
def _update_parameter(data: Data, varied_keys: list[str], parameter: list[float]):
for keys, values in zip(varied_keys, parameter):
if keys in data.parameter.keys():
data.parameter[keys].scaled_value = values
data.parameter[keys].namespace[keys] = data.parameter[keys].value
class FitRoutine(object):
def __init__(self, mode='lsq'):
self.fitmethod = mode
self.data = []
self.fit_model = None
self._no_own_model = []
self.parameter = Parameters()
self.result = []
self.linked = []
self._abort = False
@ -81,29 +140,27 @@ class FitRoutine(object):
return self.fit_model
def set_link_parameter(self, parameter: tuple, replacement: tuple):
def set_link_parameter(self, dismissed_param: tuple[Model | Data, str], replacement: tuple[Model, str]):
if isinstance(replacement[0], Model):
if replacement[1] not in replacement[0].global_parameter:
raise KeyError(f'Parameter at pos {replacement[1]} of '
f'model {str(replacement[0])} is not global')
if replacement[1] not in replacement[0].parameter:
raise KeyError(f'Parameter {replacement[1]} of '
f'model {replacement[0]} is not global')
if isinstance(parameter[0], Model):
warnings.warn(f'Replaced parameter at pos {parameter[1]} in {str(parameter[0])} '
if isinstance(dismissed_param[0], Model):
warnings.warn(f'Replaced parameter {dismissed_param[1]} in {dismissed_param[0]} '
f'becomes global with linkage.')
self.linked.append((*parameter, *replacement))
self.linked.append((*dismissed_param, *replacement))
def prepare_links(self):
self._no_own_model = []
self.parameter = Parameters()
_found_models = {}
linked_sender = {}
for v in self.data:
linked_sender[v] = set()
self.parameter.update(v.parameter.copy())
# set temporaray model
# set temporary model
if v.model is None:
v.model = self.fit_model
self._no_own_model.append(v)
@ -111,8 +168,6 @@ class FitRoutine(object):
# register model
if v.model not in _found_models:
_found_models[v.model] = []
m_param = v.model.parameter.copy()
self.parameter.update(m_param)
_found_models[v.model].append(v)
@ -120,24 +175,21 @@ class FitRoutine(object):
linked_sender[v.model] = set()
linked_parameter = {}
for par, par_parm, repl, repl_par in self.linked:
if isinstance(par, Data):
if isinstance(repl, Data):
linked_parameter[par.para_keys[par_parm]] = repl.para_keys[repl_par]
else:
linked_parameter[par.para_keys[par_parm]] = repl.global_parameter[repl_par]
for dismiss_model, dismiss_param, replace_model, replace_param in self.linked:
linked_sender[replace_model].add(dismiss_model)
linked_sender[replace_model].add(replace_model)
else:
if isinstance(repl, Data):
par.global_parameter[par_parm] = repl.para_keys[repl_par]
else:
par.global_parameter[par_parm] = repl.global_parameter[repl_par]
replace_key = replace_model.parameter.get_key(replace_param)
dismiss_key = dismiss_model.parameter.get_key(dismiss_param)
linked_sender[repl].add(par)
linked_sender[par].add(repl)
if isinstance(replace_model, Data):
linked_parameter[dismiss_key] = replace_key
else:
p = dismiss_model.set_global_parameter(dismiss_param, replace_key)
p._expr_disp = replace_param
for mm, m_data in _found_models.items():
if mm.global_parameter:
if mm.parameter:
for dd in m_data:
linked_sender[mm].add(dd)
linked_sender[dd].add(mm)
@ -169,15 +221,13 @@ class FitRoutine(object):
logger.info('Fit aborted by user')
self._abort = True
def run(self, mode: str=None):
def run(self, mode: str = None):
self._abort = False
self.parameter = Parameters()
if mode is None:
mode = self.fitmethod
fit_groups, linked_parameter = self.prepare_links()
for data_groups in fit_groups:
if len(data_groups) == 1 and not self.linked:
data = data_groups[0]
@ -208,8 +258,21 @@ class FitRoutine(object):
self.unprep_run()
for r in self.result:
r.pprint()
return self.result
def make_preview(self, x: np.ndarray) -> list[np.ndarray]:
y_pred = []
fit_groups, linked_parameter = self.prepare_links()
for data_groups in fit_groups:
data = data_groups[0]
actual_parameters = [p.value for p in data.parameter.values()]
y_pred.append(data.func(actual_parameters, x))
return y_pred
def _prep_data(self, data):
if data.get_model() is None:
data._model = self.fit_model
@ -237,22 +300,16 @@ class FitRoutine(object):
var = []
data_pars = []
# loopyloop over data that belong to one fit (linked or global)
# loopy-loop over data that belong to one fit (linked or global)
for data in data_group:
actual_pars = []
for i, (p_k, v_k) in enumerate(data.parameter.items()):
p_k_used = p_k
v_k_used = v_k
# is parameter replaced by global parameter?
if i in data.model.global_parameter:
p_k_used = data.model.global_parameter[i]
v_k_used = self.parameter[p_k_used]
for k, v in data.model.parameter.items():
data.replace_parameter(k, v)
# links trump global parameter
if p_k_used in linked:
p_k_used = linked[p_k_used]
v_k_used = self.parameter[p_k_used]
actual_pars = []
for i, p_k in enumerate(data.para_keys):
p_k_used = p_k
v_k_used = data.parameter[p_k]
actual_pars.append(p_k_used)
# parameter is variable and was not found before as shared parameter
@ -271,48 +328,7 @@ class FitRoutine(object):
d._model = None
self._no_own_model = []
# COST FUNCTIONS: f(x) - y (least_square, minimize), and f(x) (ODR)
def __cost_scipy(self, p, data, varpars, used_pars):
for keys, values in zip(varpars, p):
self.parameter[keys].scaled_value = values
actual_parameters = [self.parameter[keys].value for keys in used_pars]
return data.cost(actual_parameters)
def __cost_odr(self, p, data, varpars, used_pars):
for keys, values in zip(varpars, p):
self.parameter[keys].scaled_value = values
actual_parameters = [self.parameter[keys].value for keys in used_pars]
return data.func(actual_parameters, data.x)
def __cost_scipy_glob(self, p, data, varpars, used_pars):
# replace values
for keys, values in zip(varpars, p):
self.parameter[keys].scaled_value = values
r = []
# unpack parameter and calculate y values and concatenate all
for values, p_idx in zip(data, used_pars):
actual_parameters = [self.parameter[keys].value for keys in p_idx]
r = np.r_[r, values.cost(actual_parameters)]
return r
def __cost_odr_glob(self, p, data, varpars, used_pars):
# replace values
for keys, values in zip(varpars, p):
self.parameter[keys].scaled_value = values
r = []
# unpack parameter and calculate y values and concatenate all
for values, p_idx in zip(data, used_pars):
actual_parameters = [self.parameter[keys].value for keys in p_idx]
r = np.r_[r, values.func(actual_parameters, values.x)]
return r
Parameters.reset()
def _least_squares_single(self, data, p0, lb, ub, var):
self.step = 0
@ -322,7 +338,7 @@ class FitRoutine(object):
if self._abort:
raise FitAbortException(f'Fit aborted by user')
return self.__cost_scipy(p, data, var, data.para_keys)
return _cost_scipy(p, data, var, data.para_keys)
with np.errstate(all='ignore'):
res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0))
@ -336,7 +352,7 @@ class FitRoutine(object):
self.step += 1
if self._abort:
raise FitAbortException(f'Fit aborted by user')
return self.__cost_scipy_glob(p, data, var, data_pars)
return _cost_scipy_glob(p, data, var, data_pars)
with np.errstate(all='ignore'):
res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0))
@ -351,7 +367,7 @@ class FitRoutine(object):
self.step += 1
if self._abort:
raise FitAbortException(f'Fit aborted by user')
return (self.__cost_scipy(p, data, var, data.para_keys)**2).sum()
return (_cost_scipy(p, data, var, data.para_keys) ** 2).sum()
with np.errstate(all='ignore'):
res = optimize.minimize(cost, p0, bounds=[(b1, b2) for (b1, b2) in zip(lb, ub)],
@ -364,7 +380,7 @@ class FitRoutine(object):
self.step += 1
if self._abort:
raise FitAbortException(f'Fit aborted by user')
return (self.__cost_scipy_glob(p, data, var, data_pars)**2).sum()
return (_cost_scipy_glob(p, data, var, data_pars) ** 2).sum()
with np.errstate(all='ignore'):
res = optimize.minimize(cost, p0, bounds=[(b1, b2) for (b1, b2) in zip(lb, ub)],
@ -380,13 +396,18 @@ class FitRoutine(object):
self.step += 1
if self._abort:
raise FitAbortException(f'Fit aborted by user')
return self.__cost_odr(p, data, var_pars, data.para_keys)
return _cost_odr(p, data, var_pars, data.para_keys)
odr_model = odr.Model(func)
corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0)
self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)),
err=res.sd_beta, corr=corr, partial_corr=partial_corr)
def _odr_fit(self, odr_data, odr_model, p0):
o = odr.ODR(odr_data, odr_model, beta0=p0)
res = o.run()
corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var
try:
corr_inv = np.linalg.inv(corr)
@ -395,16 +416,14 @@ class FitRoutine(object):
partial_corr[np.diag_indices_from(partial_corr)] = 1.
except np.linalg.LinAlgError:
partial_corr = corr
self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)),
err=res.sd_beta, corr=corr, partial_corr=partial_corr)
return corr, partial_corr, res
def _odr_global(self, data, p0, var, data_pars):
def func(p, _):
self.step += 1
if self._abort:
raise FitAbortException(f'Fit aborted by user')
return self.__cost_odr_glob(p, data, var, data_pars)
return _cost_odr_glob(p, data, var, data_pars)
x = []
y = []
@ -415,17 +434,7 @@ class FitRoutine(object):
odr_data = odr.Data(x, y)
odr_model = odr.Model(func)
o = odr.ODR(odr_data, odr_model, beta0=p0, ifixb=var)
res = o.run()
corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var
try:
corr_inv = np.linalg.inv(corr)
corr_inv_diag = np.diag(np.sqrt(1 / np.diag(corr_inv)))
partial_corr = -1. * np.dot(np.dot(corr_inv_diag, corr_inv), corr_inv_diag) # Partial correlation matrix
partial_corr[np.diag_indices_from(partial_corr)] = 1.
except np.linalg.LinAlgError:
partial_corr = corr
corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0)
for v, var_pars_k in zip(data, data_pars):
self.make_results(v, res.beta, var, var_pars_k, (sum(len(d) for d in data), len(p0)),
@ -439,15 +448,17 @@ class FitRoutine(object):
# update parameter values
for keys, p_value, err_value in zip(var_pars, p, err):
self.parameter[keys].scaled_value = p_value
self.parameter[keys].scaled_error = err_value
if keys in data.parameter.keys():
data.parameter[keys].scaled_value = p_value
data.parameter[keys].scaled_error = err_value
data.parameter[keys].namespace[keys] = data.parameter[keys].value
combinations = list(product(var_pars, var_pars))
actual_parameters = []
corr_idx = []
for i, p_i in enumerate(used_pars):
actual_parameters.append(self.parameter[p_i])
actual_parameters.append(data.parameter[p_i])
for j, p_j in enumerate(used_pars):
try:
# find the position of the parameter combinations
@ -508,3 +519,4 @@ class FitRoutine(object):
partial_corr = corr
return _err, corr, partial_corr

View File

@ -6,7 +6,7 @@ from typing import Sized
from numpy import inf
from ._meta import MultiModel
from .parameter import Parameters
from .parameter import Parameters, Parameter
class Model(object):
@ -25,7 +25,6 @@ class Model(object):
self.ub = [i if i is not None else inf for i in self.ub]
self.parameter = Parameters()
self.global_parameter = {}
self.is_complex = None
self._complex_part = False
@ -80,23 +79,33 @@ class Model(object):
self.fun_kwargs = {k: v.default for k, v in inspect.signature(model.func).parameters.items()
if v.default is not inspect.Parameter.empty}
def set_global_parameter(self, idx, p, var=None, lb=None, ub=None, default_bounds=False):
if idx is None:
self.parameter = Parameters()
self.global_parameter = {}
return
def set_global_parameter(self,
key: str | Parameter,
value: float | str = None,
*,
var: bool = None,
lb: float = None,
ub: float = None,
default_bounds: bool = False,
) -> Parameter:
if isinstance(key, Parameter):
p = key
key = f'p{next(Parameters.parameter_counter)}'
self.parameter.add_parameter(key, p)
else:
idx = [self.params.index(key)]
if default_bounds:
if lb is None:
lb = [self.lb[i] for i in idx]
if ub is None:
ub = [self.lb[i] for i in idx]
gp = self.parameter.add_parameter(p, var=var, lb=lb, ub=ub)
for k, v in zip(idx, gp):
self.global_parameter[k] = v
p = self.parameter.add(key, value, var=var, lb=lb, ub=ub)
p.is_global = True
return gp
return p
@staticmethod
def _prep(param_len, val):

View File

@ -1,94 +1,134 @@
from __future__ import annotations
from numbers import Number
import re
from itertools import count
from io import StringIO
import numpy as np
class Parameters(dict):
count = count()
parameter_counter = count()
# is one global namespace a good idea?
namespace: dict = {}
def __str__(self):
return 'Parameters:\n' + '\n'.join([str(k)+': '+str(v) for k, v in self.items()])
def __init__(self):
super().__init__()
self._mapping: dict = {}
def __getitem__(self, item):
if isinstance(item, (list, tuple, np.ndarray)):
values = []
for item_i in item:
values.append(super().__getitem__(item_i))
return values
def __str__(self) -> str:
return 'Parameters:\n' + '\n'.join([f'{k}: {v}' for k, v in self.items()])
def __getitem__(self, item) -> Parameter:
if item in self._mapping:
return super().__getitem__(self._mapping[item])
else:
return super().__getitem__(item)
def __setitem__(self, key, value):
self.add_parameter(key, value)
def __contains__(self, item):
for v in self.values():
if item == v.name:
return True
return False
def add(self,
name: str,
value: str | float | int = None,
*,
var: bool = True,
lb: float = -np.inf, ub: float = np.inf) -> Parameter:
par = Parameter(name=name, value=value, var=var, lb=lb, ub=ub)
key = f'p{next(Parameters.parameter_counter)}'
self.add_parameter(key, par)
return par
def add_parameter(self, key: str, parameter: Parameter):
self._mapping[parameter.name] = key
super().__setitem__(key, parameter)
parameter.eval_allowed = False
self.namespace[key] = parameter.value
parameter.namespace = self.namespace
parameter.eval_allowed = True
self.update_namespace()
def replace_parameter(self, key_out: str, key_in: str, parameter: Parameter):
self.add_parameter(key_in, parameter)
for k, v in self._mapping.items():
if v == key_out:
self._mapping[k] = key_in
break
if key_out in self.namespace:
del self.namespace[key_out]
def fix(self):
for v in self.keys():
v._value = v.value
v.namespace = {}
@staticmethod
def _prep_bounds(val, p_len: int) -> list:
# helper function to ensure that bounds and variable are of parameter shape
if isinstance(val, (Number, bool)) or val is None:
return [val] * p_len
def reset():
Parameters.namespace = {}
elif len(val) == p_len:
return val
elif len(val) == 1:
return [val[0]] * p_len
else:
raise ValueError('Input {} has wrong dimensions'.format(val))
def add_parameter(self, param, var=None, lb=None, ub=None):
if isinstance(param, Number):
param = [param]
p_len = len(param)
# make list if only single value is given
var = self._prep_bounds(var, p_len)
lb = self._prep_bounds(lb, p_len)
ub = self._prep_bounds(ub, p_len)
new_keys = []
for i in range(p_len):
new_idx = next(self.count)
new_keys.append(new_idx)
self[new_idx] = Parameter(param[i], var=var[i], lb=lb[i], ub=ub[i])
return new_keys
def copy(self):
p = Parameters()
def get_key(self, name: str) -> str | None:
for k, v in self.items():
p[k] = Parameter(v.value, var=v.var, lb=v.lb, ub=v.ub)
if name == v.name:
return k
if len(p) == 0:
return p
max_k = max(p.keys())
c = next(p.count)
while c < max_k:
c = next(p.count)
return p
return
def get_state(self):
return {k: v.get_state() for k, v in self.items()}
def update_namespace(self):
for p in self.values():
try:
p.value
except NameError:
expression = p._expr_disp
for n, k in self._mapping.items():
expression, num_replaced = re.subn(re.escape(n), k, expression)
p._expr = expression
class Parameter:
"""
Container for one parameter
"""
__slots__ = ['name', 'value', 'error', 'init_val', 'var', 'lb', 'ub', 'scale', 'function']
def __init__(self, value: float, var: bool = True, lb: float = -np.inf, ub: float = np.inf):
self.lb = lb if lb is not None else -np.inf
self.ub = ub if ub is not None else np.inf
# TODO Parameter should know its own key
def __init__(self, name: str, value: float | str, var: bool = True, lb: float = -np.inf, ub: float = np.inf):
self._value: float | None = None
self.var: bool = bool(var) if var is not None else True
self.error: None | float = None if self.var is False else 0.0
self.name: str = name
self.function: str = ""
if self.lb <= value <= self.ub:
self.value = value
self.lb: None | float = lb if lb is not None else -np.inf
self.ub: float | None = ub if ub is not None else np.inf
self.namespace: dict = {}
self.eval_allowed: bool = True
self._expr: None | str = None
self._expr_disp: None | str = None
self.is_global = False
if isinstance(value, str):
self._expr_disp = value
self._expr = value
self.var = False
else:
if self.lb <= value <= self.ub:
self._value = value
else:
print(value, self.lb, self.ub)
raise ValueError('Value of parameter is outside bounds')
self.init_val = value
@ -100,25 +140,31 @@ class Parameter:
if self.scale == 0:
self.scale = 1.
self.var = bool(var) if var is not None else True
self.error = None if self.var is False else 0.0
self.name = ''
self.function = ''
def __str__(self):
start = ''
def __str__(self) -> str:
start = StringIO()
if self.name:
if self.function:
start = f'{self.name} ({self.function}): '
start.write(f"{self.name} ({self.function})")
else:
start = self.name + ': '
start.write(self.name)
if self.is_global:
start.write("*")
start.write(": ")
if self.var:
return start + f'{self.value:.4g} +/- {self.error:.4g}, init={self.init_val}'
start.write(f"{self.value:.4g} +/- {self.error:.4g}, init={self.init_val}")
else:
return start + f'{self.value:} (fixed)'
start.write(f"{self.value:.4g}")
if self._expr is None:
start.write(" (fixed)")
else:
start.write(f" (calc: {self._expr_disp})")
def __add__(self, other: Parameter | float) -> float:
return start.getvalue()
def __add__(self, other: Parameter | float | int) -> float:
if isinstance(other, (float, int)):
return self.value + other
elif isinstance(other, Parameter):
@ -128,30 +174,39 @@ class Parameter:
return self.__add__(other)
@property
def scaled_value(self):
def scaled_value(self) -> float:
return self.value / self.scale
@scaled_value.setter
def scaled_value(self, value):
self.value = value * self.scale
def scaled_value(self, value: float) -> None:
self._value = value * self.scale
@property
def scaled_error(self):
if self.error is None:
return self.error
else:
def value(self) -> float | None:
if self._value is not None:
return self._value
if self._expr is not None and self.eval_allowed:
return eval(self._expr, {}, self.namespace)
return
@property
def scaled_error(self) -> None | float:
if self.error is not None:
return self.error / self.scale
return
@scaled_error.setter
def scaled_error(self, value):
def scaled_error(self, value) -> None:
self.error = value * self.scale
def get_state(self):
def get_state(self) -> dict:
return {slot: getattr(self, slot) for slot in self.__slots__}
@staticmethod
def set_state(state: dict):
def set_state(state: dict) -> Parameter:
par = Parameter(state.pop('value'))
for k, v in state.items():
setattr(par, k, v)
@ -159,9 +214,28 @@ class Parameter:
return par
@property
def full_name(self):
def full_name(self) -> str:
name = self.name
if self.function:
name += ' (' + self.function + ')'
name += f" ({self.function})"
return name
def copy(self) -> Parameter:
if self._expr:
val = self._expr_disp
else:
val = self._value
para_copy = Parameter(name=self.name, value=val, var=self.var, lb=self.lb, ub=self.ub)
para_copy._expr = self._expr
para_copy.namespace = self.namespace
para_copy.is_global = self.is_global
para_copy.error = self.error
para_copy.function = self.function
return para_copy
def fix(self):
self._value = self.value
self.namespace = {}

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import re
from collections import OrderedDict
from io import StringIO
from pathlib import Path
from typing import Any
@ -186,7 +187,7 @@ class FitResult(Points):
nice_name = m.group(1)
if func_number in split_funcs:
nice_func = split_funcs[func_number]
pvalue.fix()
pvalue.name = nice_name
pvalue.function = nice_func
parameter_dic[pname] = pvalue
@ -223,27 +224,30 @@ class FitResult(Points):
return self.nobs-self.nvar
def pprint(self, statistics=True, correlations=True):
print('Fit result:')
print(' model :', self.name)
print(' #data :', self.nobs)
print(' #var :', self.nvar)
print('\nParameter')
print(self.parameter_string())
s = StringIO()
s.write('Fit result:\n')
s.write(f' model : {self.name}\n')
s.write(f' #data : {self.nobs}\n')
s.write(f' #var : {self.nvar}\n')
s.write('\nParameter\n')
s.write(self.parameter_string())
if statistics:
print('Statistics')
s.write('\nStatistics\n')
for k, v in self.statistics.items():
print(f' {k} : {v:.4f}')
s.write(f' {k} : {v:.4f}\n')
if correlations and self.correlation is not None:
print('\nCorrelation (partial corr.)')
print(self._correlation_string())
print()
s.write('\nCorrelation (partial corr.)\n')
s.write(self._correlation_string())
s.write('\n')
print(s.getvalue())
def parameter_string(self):
ret_val = ''
for pval in self.parameter.values():
for pkey, pval in self.parameter.items():
ret_val += convert(str(pval), old='tex', new='str') + '\n'
if self.fun_kwargs:
@ -255,9 +259,7 @@ class FitResult(Points):
def _correlation_string(self):
ret_val = ''
for p_i, p_j, corr_ij, pcorr_ij in self.correlation_list():
ret_val += ' {} / {} : {:.4f} ({:.4f})\n'.format(convert(p_i, old='tex', new='str'),
convert(p_j, old='tex', new='str'),
corr_ij, pcorr_ij)
ret_val += f" {convert(p_i, old='tex', new='str')} / {convert(p_j, old='tex', new='str')} : {corr_ij:.4f} ({pcorr_ij:.4f})\n"
return ret_val
def correlation_list(self, limit=0.1):

View File

@ -35,8 +35,8 @@ class Gaussian:
class Lorentzian:
type = 'Spectrum'
name = 'Lorentzian'
equation = 'A (2/\pi)w/[4*(x-\mu)^{2} + w^{2}] + A_{0}'
params = ['A', '\mu', 'w', 'A_{0}']
equation = r'A (2/\pi)w/[4*(x-\mu)^{2} + w^{2}] + A_{0}'
params = ['A', r'\mu', 'w', 'A_{0}']
ext_params = None
bounds = [(0, None), (None, None), (0, None), (None, None)]
@ -62,9 +62,9 @@ class Lorentzian:
class PseudoVoigt:
type = 'Spectrum'
name = 'Pseudo Voigt'
equation = 'A [R*2/\pi*w/[4*(x-\mu)^{2} + w^{2}] + ' \
'(1-R)*sqrt(4*ln(2)/pi)/w*exp(-4*ln(2)[(x-\mu)/w]^{2})] + A_{0}'
params = ['A', 'R', '\mu', 'w', 'A_{0}']
equation = r'A [R*2/\pi*w/[4*(x-\mu)^{2} + w^{2}] + ' \
r'(1-R)*sqrt(4*ln(2)/pi)/w*exp(-4*ln(2)[(x-\mu)/w]^{2})] + A_{0}'
params = ['A', 'R', r'\mu', 'w', 'A_{0}']
ext_params = None
bounds = [(0, None), (0, 1), (None, None), (0, None)]

View File

@ -7,7 +7,7 @@
<x>0</x>
<y>0</y>
<width>365</width>
<height>78</height>
<height>66</height>
</rect>
</property>
<property name="sizePolicy">
@ -62,7 +62,7 @@
<item>
<widget class="LineEdit" name="parameter_line">
<property name="sizePolicy">
<sizepolicy hsizetype="Fixed" vsizetype="Fixed">
<sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
@ -78,19 +78,6 @@
</property>
</widget>
</item>
<item>
<spacer name="horizontalSpacer">
<property name="orientation">
<enum>Qt::Horizontal</enum>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>40</width>
<height>20</height>
</size>
</property>
</spacer>
</item>
<item>
<widget class="QCheckBox" name="fixed_check">
<property name="text">
@ -105,19 +92,6 @@
</property>
</widget>
</item>
<item>
<widget class="QToolButton" name="toolButton">
<property name="text">
<string/>
</property>
<property name="popupMode">
<enum>QToolButton::InstantPopup</enum>
</property>
<property name="arrowType">
<enum>Qt::RightArrow</enum>
</property>
</widget>
</item>
</layout>
</item>
<item>