Merge branch 'fit_constraints'
# Conflicts: # src/gui_qt/main/management.py
This commit is contained in:
		| @@ -1,10 +1,11 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
|  | ||||
| # Form implementation generated from reading ui file 'resources/_ui/fitmodelwidget.ui' | ||||
| # Form implementation generated from reading ui file 'src/resources/_ui/fitmodelwidget.ui' | ||||
| # | ||||
| # Created by: PyQt5 UI code generator 5.12.3 | ||||
| # Created by: PyQt5 UI code generator 5.15.9 | ||||
| # | ||||
| # WARNING! All changes made in this file will be lost! | ||||
| # WARNING: Any manual changes made to this file will be lost when pyuic5 is | ||||
| # run again.  Do not edit this file unless you know what you are doing. | ||||
|  | ||||
|  | ||||
| from PyQt5 import QtCore, QtGui, QtWidgets | ||||
| @@ -13,7 +14,7 @@ from PyQt5 import QtCore, QtGui, QtWidgets | ||||
| class Ui_FitParameter(object): | ||||
|     def setupUi(self, FitParameter): | ||||
|         FitParameter.setObjectName("FitParameter") | ||||
|         FitParameter.resize(365, 78) | ||||
|         FitParameter.resize(365, 66) | ||||
|         sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.MinimumExpanding) | ||||
|         sizePolicy.setHorizontalStretch(0) | ||||
|         sizePolicy.setVerticalStretch(0) | ||||
| @@ -36,7 +37,7 @@ class Ui_FitParameter(object): | ||||
|         self.parametername.setObjectName("parametername") | ||||
|         self.horizontalLayout_2.addWidget(self.parametername) | ||||
|         self.parameter_line = LineEdit(FitParameter) | ||||
|         sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed) | ||||
|         sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed) | ||||
|         sizePolicy.setHorizontalStretch(0) | ||||
|         sizePolicy.setVerticalStretch(0) | ||||
|         sizePolicy.setHeightForWidth(self.parameter_line.sizePolicy().hasHeightForWidth()) | ||||
| @@ -44,20 +45,12 @@ class Ui_FitParameter(object): | ||||
|         self.parameter_line.setText("") | ||||
|         self.parameter_line.setObjectName("parameter_line") | ||||
|         self.horizontalLayout_2.addWidget(self.parameter_line) | ||||
|         spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) | ||||
|         self.horizontalLayout_2.addItem(spacerItem) | ||||
|         self.fixed_check = QtWidgets.QCheckBox(FitParameter) | ||||
|         self.fixed_check.setObjectName("fixed_check") | ||||
|         self.horizontalLayout_2.addWidget(self.fixed_check) | ||||
|         self.global_checkbox = QtWidgets.QCheckBox(FitParameter) | ||||
|         self.global_checkbox.setObjectName("global_checkbox") | ||||
|         self.horizontalLayout_2.addWidget(self.global_checkbox) | ||||
|         self.toolButton = QtWidgets.QToolButton(FitParameter) | ||||
|         self.toolButton.setText("") | ||||
|         self.toolButton.setPopupMode(QtWidgets.QToolButton.InstantPopup) | ||||
|         self.toolButton.setArrowType(QtCore.Qt.RightArrow) | ||||
|         self.toolButton.setObjectName("toolButton") | ||||
|         self.horizontalLayout_2.addWidget(self.toolButton) | ||||
|         self.verticalLayout.addLayout(self.horizontalLayout_2) | ||||
|         self.frame = QtWidgets.QFrame(FitParameter) | ||||
|         sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Maximum) | ||||
|   | ||||
| @@ -8,6 +8,7 @@ from pyqtgraph import mkPen | ||||
|  | ||||
| from nmreval.data.points import Points | ||||
| from nmreval.data.signals import Signal | ||||
| from nmreval.lib.logger import logger | ||||
| from nmreval.utils.text import convert | ||||
| from nmreval.data.bds import BDS | ||||
| from nmreval.data.dsc import DSC | ||||
| @@ -356,7 +357,7 @@ class ExperimentContainer(QtCore.QObject): | ||||
|         elif mode in ['imag', 'all'] and self.plot_imag is not None: | ||||
|             self.plot_imag.set_symbol(symbol=symbol, size=size, color=color) | ||||
|         else: | ||||
|             print('Updating symbol failed for ' + str(self.id)) | ||||
|             logger.warning(f'Updating symbol failed for {self.id}') | ||||
|  | ||||
|     def setLine(self, *, width=None, style=None, color=None, mode='real'): | ||||
|         if mode in ['real', 'all']: | ||||
| @@ -368,7 +369,7 @@ class ExperimentContainer(QtCore.QObject): | ||||
|         elif mode in ['imag', 'all'] and self.plot_imag is not None: | ||||
|             self.plot_imag.set_line(width=width, style=style, color=color) | ||||
|         else: | ||||
|             print('Updating line failed for ' + str(self.id)) | ||||
|             logger.warning(f'Updating line failed for {self.id}') | ||||
|  | ||||
|     def update_property(self, key1: str, key2: str, value: Any): | ||||
|         keykey = key2.split() | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| from nmreval.lib.logger import logger | ||||
| from nmreval.math import apodization | ||||
| from nmreval.lib.importer import find_models | ||||
| from nmreval.utils.text import convert | ||||
| @@ -67,7 +68,7 @@ class EditSignalWidget(QtWidgets.QWidget, Ui_Form): | ||||
|                 self.do_something.emit(sender, (ph0, ph1, pvt)) | ||||
|  | ||||
|         else: | ||||
|             print('You should never reach this by accident.') | ||||
|             logger.warning(f'You should never reach this by accident, invalid sender {sender!r}') | ||||
|  | ||||
|     @QtCore.pyqtSlot(int, name='on_apodcombobox_currentIndexChanged') | ||||
|     def change_apodization(self, index): | ||||
|   | ||||
| @@ -19,19 +19,20 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter): | ||||
|         super().__init__(parent) | ||||
|         self.setupUi(self) | ||||
|  | ||||
|         self.parametername.setText(label + ' ') | ||||
|         self.name = label | ||||
|  | ||||
|         self.parametername.setText(convert(label) + ' ') | ||||
|  | ||||
|         validator = QtGui.QDoubleValidator() | ||||
|         self.parameter_line.setValidator(validator) | ||||
|         self.parameter_line.setText('1') | ||||
|         self.parameter_line.setMaximumWidth(240) | ||||
|         self.lineEdit.setMaximumWidth(60) | ||||
|         self.lineEdit_2.setMaximumWidth(60) | ||||
|         self.parameter_line.setMaximumWidth(160) | ||||
|         self.lineEdit.setMaximumWidth(100) | ||||
|         self.lineEdit_2.setMaximumWidth(100) | ||||
|  | ||||
|         self.label_3.setText(f'< {label} <') | ||||
|         self.label_3.setText(f'< {convert(label)} <') | ||||
|  | ||||
|         self.checkBox.stateChanged.connect(self.enableBounds) | ||||
|         self.global_checkbox.stateChanged.connect(lambda: self.state_changed.emit()) | ||||
|         self.parameter_line.editingFinished.connect(self.update_parameter) | ||||
|         self.parameter_line.values_requested.connect(lambda: self.value_requested.emit(self)) | ||||
|         self.parameter_line.replace_single_values.connect(lambda: self.replace_single_value.emit(None)) | ||||
|         self.parameter_line.editingFinished.connect(lambda: self.value_changed.emit(self.parameter_line.text())) | ||||
| @@ -40,18 +41,12 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter): | ||||
|         if fixed: | ||||
|             self.fixed_check.hide() | ||||
|  | ||||
|         self.menu = QtWidgets.QMenu(self) | ||||
|         self.add_links() | ||||
|  | ||||
|         self.is_linked = None | ||||
|         self.parameter_pos = None | ||||
|         self.func_idx = None | ||||
|  | ||||
|         self._linetext = '1' | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         return convert(self.parametername.text().strip(), old='html', new='str') | ||||
|         self.menu = QtWidgets.QMenu(self) | ||||
|  | ||||
|     def set_parameter_string(self, p: str): | ||||
|         self.parameter_line.setText(p) | ||||
| @@ -71,38 +66,24 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter): | ||||
|  | ||||
|     def set_parameter(self, p: float | None, bds: tuple[float, float, bool] = None, | ||||
|                       fixed: bool = None, glob: bool = None): | ||||
|         if p is None: | ||||
|             # bad hack: linked parameter return (None, linked parameter) | ||||
|             # if p is None -> parameter is linked to argument given by bds | ||||
|             self.link_parameter(linkto=bds) | ||||
|         else: | ||||
|             ptext = f'{p:.4g}' | ||||
|         ptext = f'{p:.4g}' | ||||
|  | ||||
|             self.set_parameter_string(ptext) | ||||
|         self.set_parameter_string(ptext) | ||||
|  | ||||
|             if bds is not None: | ||||
|                 self.set_bounds(*bds) | ||||
|         if bds is not None: | ||||
|             self.set_bounds(*bds) | ||||
|  | ||||
|             if fixed is not None: | ||||
|                 self.fixed_check.setCheckState(QtCore.Qt.Unchecked if fixed else QtCore.Qt.Checked) | ||||
|         if fixed is not None: | ||||
|             self.fixed_check.setCheckState(QtCore.Qt.Unchecked if fixed else QtCore.Qt.Checked) | ||||
|  | ||||
|             if glob is not None: | ||||
|                 self.global_checkbox.setCheckState(QtCore.Qt.Checked if glob else QtCore.Qt.Unchecked) | ||||
|         if glob is not None: | ||||
|             self.global_checkbox.setCheckState(QtCore.Qt.Checked if glob else QtCore.Qt.Unchecked) | ||||
|  | ||||
|     def get_parameter(self): | ||||
|         if self.is_linked: | ||||
|             try: | ||||
|                 p = float(self._linetext) | ||||
|             except ValueError: | ||||
|                 p = 1.0 | ||||
|         else: | ||||
|             try: | ||||
|                 p = float(self.parameter_line.text().replace(',', '.')) | ||||
|             except ValueError: | ||||
|                 _ = QtWidgets.QMessageBox().warning(self, 'Invalid value', | ||||
|                                                     f'{self.parametername.text()} contains invalid values', | ||||
|                                                     QtWidgets.QMessageBox.Cancel) | ||||
|                 return None | ||||
|         try: | ||||
|             p = float(self.parameter_line.text().replace(',', '.')) | ||||
|         except ValueError: | ||||
|             p = self.parameter_line.text().replace(',', '.') | ||||
|  | ||||
|         if self.checkBox.isChecked(): | ||||
|             try: | ||||
| @@ -119,75 +100,27 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter): | ||||
|  | ||||
|         bounds = (lb, rb) | ||||
|  | ||||
|         return p, bounds, not self.fixed_check.isChecked(), self.global_checkbox.isChecked(), self.is_linked | ||||
|         return p, bounds, not self.fixed_check.isChecked(), self.global_checkbox.isChecked() | ||||
|  | ||||
|     @QtCore.pyqtSlot(bool) | ||||
|     def set_fixed(self, state: bool): | ||||
|         # self.global_checkbox.setVisible(not state) | ||||
|         self.frame.setVisible(not state) | ||||
|  | ||||
|     def add_links(self, parameter: dict = None): | ||||
|         if parameter is None: | ||||
|             parameter = {} | ||||
|         self.menu.clear() | ||||
|  | ||||
|         ac = QtWidgets.QAction('Link to...', self) | ||||
|         ac.triggered.connect(self.link_parameter) | ||||
|         self.menu.addAction(ac) | ||||
|  | ||||
|         for model_key, model_funcs in parameter.items(): | ||||
|             m = QtWidgets.QMenu('Model ' + model_key, self) | ||||
|             for func_name, func_params in model_funcs.items(): | ||||
|                 m2 = QtWidgets.QMenu(func_name, m) | ||||
|                 for p_name, idx in func_params: | ||||
|                     ac = QtWidgets.QAction(p_name, m2) | ||||
|                     ac.setData((model_key, *idx)) | ||||
|                     ac.triggered.connect(self.link_parameter) | ||||
|                     m2.addAction(ac) | ||||
|                 m.addMenu(m2) | ||||
|             self.menu.addMenu(m) | ||||
|  | ||||
|         self.toolButton.setMenu(self.menu) | ||||
|  | ||||
|     @QtCore.pyqtSlot() | ||||
|     def link_parameter(self, linkto=None): | ||||
|         if linkto is None: | ||||
|             action = self.sender() | ||||
|         else: | ||||
|             action = False | ||||
|             for m in self.menu.actions(): | ||||
|                 if m.menu(): | ||||
|                     for a in m.menu().actions(): | ||||
|                         if a.data() == linkto: | ||||
|                             action = a | ||||
|                             break | ||||
|                 if action: | ||||
|                     break | ||||
|  | ||||
|         if (self.func_idx, self.parameter_pos) == action.data(): | ||||
|             return | ||||
|     def update_parameter(self): | ||||
|         new_value = self.parameter_line.text() | ||||
|         if not new_value: | ||||
|             self.parameter_line.setText('1') | ||||
|  | ||||
|         try: | ||||
|             new_text = f'Linked to {action.parentWidget().title()}.{action.text()}' | ||||
|             self._linetext = self.parameter_line.text() | ||||
|             self.parameter_line.setText(new_text) | ||||
|             self.parameter_line.setEnabled(False) | ||||
|             self.global_checkbox.hide() | ||||
|             self.global_checkbox.blockSignals(True) | ||||
|             self.global_checkbox.setCheckState(QtCore.Qt.Checked) | ||||
|             self.global_checkbox.blockSignals(False) | ||||
|             self.frame.hide() | ||||
|             self.is_linked = action.data() | ||||
|             float(new_value) | ||||
|             is_text = False | ||||
|         except ValueError: | ||||
|             is_text = True | ||||
|             self.global_checkbox.setCheckState(False) | ||||
|  | ||||
|         except AttributeError: | ||||
|             self.parameter_line.setText(self._linetext) | ||||
|             self.parameter_line.setEnabled(True) | ||||
|             if self.fixed_check.isEnabled(): | ||||
|                 self.global_checkbox.show() | ||||
|                 self.frame.show() | ||||
|             self.is_linked = None | ||||
|  | ||||
|         self.state_changed.emit() | ||||
|         self.set_fixed(is_text) | ||||
|  | ||||
|  | ||||
| class QSaveModelDialog(QtWidgets.QDialog, Ui_SaveDialog): | ||||
| @@ -282,8 +215,17 @@ class FitModelTree(QtWidgets.QTreeWidget): | ||||
|         idx = item.data(0, self.counterRole) | ||||
|         self.itemRemoved.emit(idx) | ||||
|  | ||||
|     def add_function(self, idx: int, cnt: int, op: int, name: str, color: QtGui.QColor | str | tuple, | ||||
|                      parent: QtWidgets.QTreeWidgetItem = None, children: list = None, active: bool = True, **kwargs): | ||||
|     def add_function(self, | ||||
|                      idx: int, | ||||
|                      cnt: int, | ||||
|                      op: int, | ||||
|                      name: str, | ||||
|                      color: QtGui.QColor | str | tuple, | ||||
|                      parent: QtWidgets.QTreeWidgetItem = None, | ||||
|                      children: list = None, | ||||
|                      active: bool = True, | ||||
|                      param_names: list[str] = None, | ||||
|                      **kwargs): | ||||
|         """ | ||||
|         Add function to tree and dictionary of functions. | ||||
|         """ | ||||
| @@ -298,6 +240,10 @@ class FitModelTree(QtWidgets.QTreeWidget): | ||||
|         it.setData(0, self.counterRole, cnt) | ||||
|         it.setData(0, self.operatorRole, op) | ||||
|         it.setText(0, name) | ||||
|         if param_names is not None: | ||||
|             it.setToolTip(0, | ||||
|                           'Parameter names:\n' + | ||||
|                           '\n'.join(f'{pn}({cnt})' for pn in param_names)) | ||||
|         it.setForeground(0, QtGui.QBrush(color)) | ||||
|  | ||||
|         it.setIcon(0, get_icon(self.icons[op])) | ||||
|   | ||||
| @@ -1,5 +1,8 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import Optional | ||||
|  | ||||
| from nmreval.fit.parameter import Parameter | ||||
| from nmreval.utils.text import convert | ||||
|  | ||||
| from ..Qt import QtWidgets, QtCore, QtGui | ||||
| @@ -62,8 +65,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): | ||||
|         self.glob_values = [1] * len(func.params) | ||||
|  | ||||
|         for k, v in enumerate(func.params): | ||||
|             name = convert(v) | ||||
|             widgt = FitModelWidget(label=name, parent=self.scrollwidget) | ||||
|             widgt = FitModelWidget(label=v, parent=self.scrollwidget) | ||||
|             widgt.parameter_pos = k | ||||
|             widgt.func_idx = idx | ||||
|             try: | ||||
| @@ -83,7 +85,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): | ||||
|             self.global_parameter.append(widgt) | ||||
|             self.scrollwidget.layout().addWidget(widgt) | ||||
|  | ||||
|             widgt2 = ParameterSingleWidget(name=name, parent=self.scrollwidget2) | ||||
|             widgt2 = ParameterSingleWidget(name=v, parent=self.scrollwidget2) | ||||
|             widgt2.valueChanged.connect(self.change_single_parameter) | ||||
|             widgt2.removeSingleValue.connect(self.change_single_parameter) | ||||
|             widgt2.installEventFilter(self) | ||||
| @@ -115,20 +117,22 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): | ||||
|         self.scrollwidget.layout().addStretch(1) | ||||
|         self.scrollwidget2.layout().addStretch(1) | ||||
|  | ||||
|     def set_links(self, parameter): | ||||
|         for w in self.global_parameter: | ||||
|             if isinstance(w, FitModelWidget): | ||||
|                 w.add_links(parameter) | ||||
|     # def set_links(self, parameter): | ||||
|     #     for w in self.global_parameter: | ||||
|     #         if isinstance(w, FitModelWidget): | ||||
|     #             w.add_links(parameter) | ||||
|  | ||||
|     @QtCore.pyqtSlot(str) | ||||
|     def change_global_parameter(self, value: str, idx: int = None): | ||||
|         if idx is None: | ||||
|             idx = self.global_parameter.index(self.sender()) | ||||
|  | ||||
|         self.glob_values[idx] = float(value) | ||||
|         # self.glob_values[idx] = float(value) | ||||
|         self.glob_values[idx] = value | ||||
|         if self.data_values[self.comboBox.currentData()][idx] is None: | ||||
|             self.data_parameter[idx].blockSignals(True) | ||||
|             self.data_parameter[idx].value = float(value) | ||||
|             # self.data_parameter[idx].value = float(value) | ||||
|             self.data_parameter[idx].value = value | ||||
|             self.data_parameter[idx].blockSignals(False) | ||||
|  | ||||
|     @QtCore.pyqtSlot(str, object) | ||||
| @@ -171,7 +175,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): | ||||
|         # disable single parameter if it is set global, enable if global is unset | ||||
|         widget = self.sender() | ||||
|         idx = self.global_parameter.index(widget) | ||||
|         enable = (widget.global_checkbox.checkState() == QtCore.Qt.Unchecked) and (widget.is_linked is None) | ||||
|         enable = (widget.global_checkbox.checkState() == QtCore.Qt.Unchecked) | ||||
|         self.data_parameter[idx].setEnabled(enable) | ||||
|  | ||||
|     def select_next_preview(self, direction): | ||||
| @@ -204,64 +208,50 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): | ||||
|         if sid not in self.data_values: | ||||
|             self.data_values[sid] = [None] * len(self.data_parameter) | ||||
|  | ||||
|     def get_parameter(self, use_func=None): | ||||
|     def get_parameter(self, use_func=None) -> tuple[dict, list]: | ||||
|         bds = [] | ||||
|         is_global = [] | ||||
|         is_fixed = [] | ||||
|         globs = [] | ||||
|         is_linked = [] | ||||
|         param_general = [] | ||||
|  | ||||
|         for g in self.global_parameter: | ||||
|             if isinstance(g, FitModelWidget): | ||||
|                 p_i, bds_i, fixed_i, global_i, link_i = g.get_parameter() | ||||
|                 p_i, bds_i, fixed_i, global_i = g.get_parameter() | ||||
|                 parameter_i = Parameter(name=g.name, value=p_i, lb=bds_i[0], ub=bds_i[1], var=fixed_i) | ||||
|                 param_general.append(parameter_i) | ||||
|  | ||||
|                 globs.append(p_i) | ||||
|                 bds.append(bds_i) | ||||
|                 is_fixed.append(fixed_i) | ||||
|                 is_global.append(global_i) | ||||
|                 is_linked.append(link_i) | ||||
|  | ||||
|         lb, ub = list(zip(*bds)) | ||||
|  | ||||
|         data_parameter = {} | ||||
|         if use_func is None: | ||||
|             use_func = list(self.data_values.keys()) | ||||
|  | ||||
|         global_p = None | ||||
|         for sid, parameter in self.data_values.items(): | ||||
|             if sid not in use_func: | ||||
|                 continue | ||||
|  | ||||
|             kw_p = {} | ||||
|             p = [] | ||||
|             if global_p is None: | ||||
|                 global_p = {'p': [], 'idx': [], 'var': [], 'ub': [], 'lb': []} | ||||
|  | ||||
|             for i, (p_i, g) in enumerate(zip(parameter, self.global_parameter)): | ||||
|                 if isinstance(g, FitModelWidget): | ||||
|                     if (p_i is None) or is_global[i]: | ||||
|                         p.append(globs[i]) | ||||
|                         if is_global[i]: | ||||
|                             if i not in global_p['idx']: | ||||
|                                 global_p['p'].append(globs[i]) | ||||
|                                 global_p['idx'].append(i) | ||||
|                                 global_p['var'].append(is_fixed[i]) | ||||
|                                 global_p['ub'].append(ub[i]) | ||||
|                                 global_p['lb'].append(lb[i]) | ||||
|                         # set has no oen value | ||||
|                         p.append(param_general[i].copy()) | ||||
|                     else: | ||||
|                         p.append(p_i) | ||||
|                         lb, ub = bds[i] | ||||
|                         try: | ||||
|                             if not (lb < p_i < ub): | ||||
|                                 raise ValueError(f'Parameter {g.name} is outside bounds ({lb}, {ub})') | ||||
|                         except TypeError: | ||||
|                             pass | ||||
|  | ||||
|                     try: | ||||
|                         if p[i] > ub[i]: | ||||
|                             raise ValueError(f'Parameter {g.name} is outside bounds ({lb[i]}, {ub[i]})') | ||||
|                     except TypeError: | ||||
|                         pass | ||||
|  | ||||
|                     try: | ||||
|                         if p[i] < lb[i]: | ||||
|                             raise ValueError(f'Parameter {g.name} is outside bounds ({lb[i]}, {ub[i]})') | ||||
|                     except TypeError: | ||||
|                         pass | ||||
|                         # create Parameter | ||||
|                         p.append( | ||||
|                             Parameter(name=g.name, value=p_i, lb=lb, ub=ub, var=is_fixed[i]) | ||||
|                         ) | ||||
|  | ||||
|                 else: | ||||
|                     if p_i is None: | ||||
| @@ -273,7 +263,15 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): | ||||
|  | ||||
|             data_parameter[sid] = (p, kw_p) | ||||
|  | ||||
|         return data_parameter, lb, ub, is_fixed, global_p, is_linked | ||||
|         global_parameter = [] | ||||
|         for param, global_flag in zip(param_general, is_global): | ||||
|             if global_flag: | ||||
|                 param.is_global = True | ||||
|                 global_parameter.append(param) | ||||
|             else: | ||||
|                 global_parameter.append(None) | ||||
|  | ||||
|         return data_parameter, global_parameter | ||||
|  | ||||
|     def set_parameter(self, set_id: str | None, parameter: list[float]) -> int: | ||||
|         num_parameter = list(filter(lambda g: not isinstance(g, SelectionWidget), self.global_parameter)) | ||||
| @@ -304,12 +302,12 @@ class ParameterSingleWidget(QtWidgets.QWidget): | ||||
|  | ||||
|         self._init_ui() | ||||
|  | ||||
|         self._name = name | ||||
|         self.name = name | ||||
|         self.label.setText(convert(name)) | ||||
|         self.label.setToolTip('If this is bold then this parameter is only for this data. ' | ||||
|                               'Otherwise, the general parameter is used and displayed') | ||||
|  | ||||
|         self.value_line.setValidator(QtGui.QDoubleValidator()) | ||||
|         # self.value_line.setValidator(QtGui.QDoubleValidator()) | ||||
|         self.value_line.textChanged.connect(lambda: self.valueChanged.emit(self.value) if self.value is not None else 0) | ||||
|         self.reset_button.clicked.connect(lambda x: self.removeSingleValue.emit()) | ||||
|  | ||||
| @@ -343,9 +341,10 @@ class ParameterSingleWidget(QtWidgets.QWidget): | ||||
|  | ||||
|     @value.setter | ||||
|     def value(self, val): | ||||
|         self.value_line.setText(f'{float(val):.5g}') | ||||
|         # self.value_line.setText(f'{float(val):.5g}') | ||||
|         self.value_line.setText(f'{val}') | ||||
|  | ||||
|     def show_as_local_parameter(self, is_local): | ||||
|     def show_as_local_parameter(self, is_local: bool): | ||||
|         if is_local: | ||||
|             self.label.setStyleSheet('font-weight: bold;') | ||||
|         else: | ||||
|   | ||||
| @@ -128,7 +128,7 @@ class QFunctionWidget(QtWidgets.QWidget, Ui_Form): | ||||
|  | ||||
|         self.newFunction.emit(idx, cnt) | ||||
|  | ||||
|         self.add_function(idx, cnt, op, name, col) | ||||
|         self.add_function(idx, cnt, op, name, col, param_names=self.functions[idx].params) | ||||
|  | ||||
|     def add_function(self, idx: int, cnt: int, op: int, | ||||
|                      name: str, color: str | tuple[float, float, float] | BaseColor, **kwargs): | ||||
| @@ -141,6 +141,7 @@ class QFunctionWidget(QtWidgets.QWidget, Ui_Form): | ||||
|             qcolor = QtGui.QColor.fromRgbF(*color) | ||||
|         else: | ||||
|             qcolor = QtGui.QColor(color) | ||||
|  | ||||
|         self.functree.add_function(idx, cnt, op, name, qcolor, **kwargs) | ||||
|  | ||||
|         f = self.functions[idx] | ||||
|   | ||||
| @@ -9,6 +9,9 @@ import numpy as np | ||||
| from pyqtgraph import mkPen | ||||
|  | ||||
| from nmreval.fit._meta import MultiModel, ModelFactory | ||||
| from nmreval.fit.data import Data | ||||
| from nmreval.fit.model import Model | ||||
| from nmreval.fit.parameter import Parameters | ||||
| from nmreval.fit.result import FitResult | ||||
|  | ||||
| from .fit_forms import FitTableWidget | ||||
| @@ -116,7 +119,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): | ||||
|  | ||||
|         # collect parameter names etc. to allow linkage | ||||
|         self._func_list[self._current_model] = self.functionwidget.get_parameter_list() | ||||
|         dialog.set_links(self._func_list) | ||||
|         # dialog.set_links(self._func_list) | ||||
|  | ||||
|         # show same tab (general parameter/Data parameter) | ||||
|         tab_idx = 0 | ||||
| @@ -219,57 +222,49 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): | ||||
|  | ||||
|     def _prepare(self, model: list, function_use: list = None, | ||||
|                  parameter: dict = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]: | ||||
|  | ||||
|         if parameter is None: | ||||
|             parameter = {'parameter': {}, 'lb': (), 'ub': (), 'var': [], | ||||
|                          'glob': {'idx': [], 'p': [], 'var': [], 'lb': [], 'ub': []}, | ||||
|                          'links': [], 'color': []} | ||||
|             parameter = { | ||||
|                 'data_parameter': {}, | ||||
|                 'global_parameter': [], | ||||
|                 'links': [], | ||||
|                 'color': [], | ||||
|             } | ||||
|  | ||||
|         for i, f in enumerate(model): | ||||
|             if not f['active']: | ||||
|                 continue | ||||
|  | ||||
|             try: | ||||
|                 p, lb, ub, var, glob, links = self.param_widgets[f['cnt']].get_parameter(function_use) | ||||
|                 p, glob = self.param_widgets[f['cnt']].get_parameter(function_use) | ||||
|             except ValueError as e: | ||||
|                 _ = QtWidgets.QMessageBox().warning(self, 'Invalid value', str(e), | ||||
|                                                     QtWidgets.QMessageBox.Ok) | ||||
|                 return {}, -1 | ||||
|  | ||||
|             p_len = len(parameter['lb']) | ||||
|  | ||||
|             parameter['lb'] += lb | ||||
|             parameter['ub'] += ub | ||||
|             parameter['var'] += var | ||||
|             parameter['links'] += links | ||||
|             parameter['color'] += [f['color']] | ||||
|             parameter['color'].append(f['color']) | ||||
|             parameter['global_parameter'].extend(glob) | ||||
|  | ||||
|             cnt = f['cnt'] | ||||
|  | ||||
|             for p_k, v_k in p.items(): | ||||
|                 if add_idx: | ||||
|                     kw_k = {f'{k}_{cnt}': v for k, v in v_k[1].items()} | ||||
|                 else: | ||||
|                     kw_k = v_k[1] | ||||
|  | ||||
|                 if p_k in parameter['parameter']: | ||||
|                     params, kw = parameter['parameter'][p_k] | ||||
|                 if p_k in parameter['data_parameter']: | ||||
|                     params, kw = parameter['data_parameter'][p_k] | ||||
|                     params += v_k[0] | ||||
|                     kw.update(kw_k) | ||||
|                 else: | ||||
|                     parameter['parameter'][p_k] = (v_k[0], kw_k) | ||||
|  | ||||
|             for g_k, g_v in glob.items(): | ||||
|                 if g_k != 'idx': | ||||
|                     parameter['glob'][g_k] += g_v | ||||
|                 else: | ||||
|                     parameter['glob']['idx'] += [idx_i + p_len for idx_i in g_v] | ||||
|                     parameter['data_parameter'][p_k] = (v_k[0], kw_k) | ||||
|  | ||||
|             if add_idx: | ||||
|                 cnt += 1 | ||||
|  | ||||
|             if f['children']: | ||||
|                 # recurse for children | ||||
|                 child_parameter, cnt = self._prepare(f['children'], parameter=parameter, add_idx=add_idx, cnt=cnt) | ||||
|                 _, cnt = self._prepare(f['children'], parameter=parameter, add_idx=add_idx, cnt=cnt) | ||||
|  | ||||
|         return parameter, cnt | ||||
|  | ||||
| @@ -280,30 +275,43 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): | ||||
|         data = self.data_table.collect_data(default=self.default_combobox.currentData()) | ||||
|  | ||||
|         func_dict = {} | ||||
|         for k, mod in self.models.items(): | ||||
|             func, order, param_len = ModelFactory.create_from_list(mod) | ||||
|         for model_name, model_parameter in self.models.items(): | ||||
|             func, order, param_len = ModelFactory.create_from_list(model_parameter) | ||||
|  | ||||
|             if func is None: | ||||
|                 continue | ||||
|  | ||||
|             if k in data: | ||||
|                 parameter, _ = self._prepare(mod, function_use=data[k], add_idx=isinstance(func, MultiModel)) | ||||
|             func = Model(func) | ||||
|  | ||||
|             if model_name in data: | ||||
|                 parameter, _ = self._prepare(model_parameter, function_use=data[model_name], add_idx=isinstance(func, MultiModel)) | ||||
|  | ||||
|                 if parameter is None: | ||||
|                     return | ||||
|  | ||||
|                 for (data_parameter, _) in parameter['data_parameter'].values(): | ||||
|                     for pname, param in zip(func.params, data_parameter): | ||||
|                         param.name = pname | ||||
|  | ||||
|                 if self._complex[model_name] is not None: | ||||
|                     for p_k, p_v in parameter['data_parameter'].items(): | ||||
|                         p_v[1].update({'complex_mode': self._complex[model_name]}) | ||||
|                         parameter['data_parameter'][p_k] = p_v[0], p_v[1] | ||||
|  | ||||
|                 for pname, param_value in zip(func.params, parameter['global_parameter']): | ||||
|                     if param_value is not None: | ||||
|                         param_value.name = pname | ||||
|                         func.set_global_parameter(param_value) | ||||
|  | ||||
|                 parameter['func'] = func | ||||
|                 parameter['order'] = order | ||||
|                 parameter['len'] = param_len | ||||
|                 parameter['complex'] = self._complex[k] | ||||
|                 if self._complex[k] is not None: | ||||
|                     for p_k, p_v in parameter['parameter'].items(): | ||||
|                         p_v[1].update({'complex_mode': self._complex[k]}) | ||||
|                         parameter['parameter'][p_k] = p_v[0], p_v[1] | ||||
|                 parameter['complex'] = self._complex[model_name] | ||||
|  | ||||
|                 func_dict[k] = parameter | ||||
|                 func_dict[model_name] = parameter | ||||
|  | ||||
|         replaceable = [] | ||||
|         for k, v in func_dict.items(): | ||||
|         for model_name, v in func_dict.items(): | ||||
|             for i, link_i in enumerate(v['links']): | ||||
|                 if link_i is None: | ||||
|                     continue | ||||
| @@ -334,7 +342,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): | ||||
|                                                         QtWidgets.QMessageBox.Ok) | ||||
|                     return | ||||
|  | ||||
|                 replaceable.append((k, i, rep_model, repl_idx)) | ||||
|                 replaceable.append((model_name, i, rep_model, repl_idx)) | ||||
|  | ||||
|                 replace_value = None | ||||
|                 for p_k in f['parameter'].values(): | ||||
| @@ -412,31 +420,37 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): | ||||
|     def make_previews(self, x, models_parameters: dict): | ||||
|         self.preview_lines = [] | ||||
|  | ||||
|         # needed to create namespace | ||||
|         param_dict = Parameters() | ||||
|  | ||||
|         cnt = 0 | ||||
|         for model in models_parameters.values(): | ||||
|             f = model['func'] | ||||
|             for parameter_list in model['data_parameter'].values(): | ||||
|                 for i, p_value in enumerate(parameter_list[0]): | ||||
|                     p_value.name = f.params[i] | ||||
|                     param_dict.add_parameter(f'a{cnt}', p_value) | ||||
|                     cnt += 1 | ||||
|  | ||||
|         for k, model in models_parameters.items(): | ||||
|             f = model['func'] | ||||
|             is_complex = self._complex[k] | ||||
|  | ||||
|             parameters = model['parameter'] | ||||
|             parameters = model['data_parameter'] | ||||
|             color = model['color'] | ||||
|  | ||||
|             seen_parameter = [] | ||||
|  | ||||
|             for p, kwargs in parameters.values(): | ||||
|                 if (p, kwargs) in seen_parameter: | ||||
|                     # plot only previews with different parameter | ||||
|                     continue | ||||
|  | ||||
|                 seen_parameter.append((p, kwargs)) | ||||
|                 p_value = [pp.value for pp in p] | ||||
|  | ||||
|                 if is_complex is not None: | ||||
|                     y = f.func(x, *p, complex_mode=is_complex, **kwargs) | ||||
|                     y = f.func(x, *p_value, complex_mode=is_complex, **kwargs) | ||||
|                     if np.iscomplexobj(y): | ||||
|                         self.preview_lines.append(PlotItem(x=x, y=y.real, pen=mkPen(width=3))) | ||||
|                         self.preview_lines.append(PlotItem(x=x, y=y.imag, pen=mkPen(width=3))) | ||||
|                     else: | ||||
|                         self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3))) | ||||
|                 else: | ||||
|                     y = f.func(x, *p, **kwargs) | ||||
|                     y = f.func(x, *p_value, **kwargs) | ||||
|                     self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3))) | ||||
|  | ||||
|                 if isinstance(f, MultiModel): | ||||
| @@ -444,7 +458,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): | ||||
|                     if is_complex is not None: | ||||
|                         sub_kwargs.update({'complex_mode': is_complex}) | ||||
|  | ||||
|                     for i, s in enumerate(f.subs(x, *p, **sub_kwargs)): | ||||
|                     for i, s in enumerate(f.subs(x, *p_value, **sub_kwargs)): | ||||
|                         pen_i = mkPen(QtGui.QColor.fromRgbF(*color[i])) | ||||
|                         if np.iscomplexobj(s): | ||||
|                             self.preview_lines.append(PlotItem(x=x, y=s.real, pen=pen_i)) | ||||
| @@ -452,15 +466,17 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): | ||||
|                         else: | ||||
|                             self.preview_lines.append(PlotItem(x=x, y=s, pen=pen_i)) | ||||
|  | ||||
|         param_dict.clear() | ||||
|  | ||||
|         return self.preview_lines | ||||
|  | ||||
|     def set_parameter(self, parameter: dict[str, FitResult]): | ||||
|         # which data uses which model | ||||
|         data = self.data_table.collect_data(default=self.default_combobox.currentData()) | ||||
|  | ||||
|         glob_fit_parameter = [] | ||||
|  | ||||
|         for fitted_model, fitted_data in data.items(): | ||||
|             glob_fit_parameter = [] | ||||
|  | ||||
|             for fit_id, fit_curve in parameter.items(): | ||||
|                 if fit_id in fitted_data: | ||||
|                     fit_parameter = list(fit_curve.parameter.values()) | ||||
|   | ||||
| @@ -138,9 +138,7 @@ class DrawingsWidget(QtWidgets.QWidget, Ui_Form): | ||||
|         graph_id = self.graph_comboBox.currentData() | ||||
|         current_lines = self.lines[graph_id] | ||||
|  | ||||
|         print(remove_rows) | ||||
|         for i in reversed(remove_rows): | ||||
|             print(i) | ||||
|             self.tableWidget.removeRow(i) | ||||
|             self.line_deleted.emit(current_lines[i], graph_id) | ||||
|  | ||||
|   | ||||
| @@ -27,7 +27,6 @@ class MdiAreaTile(QtWidgets.QMdiArea): | ||||
|         pos = QtCore.QPoint(0, 0) | ||||
|  | ||||
|         for win in window_list: | ||||
|             print(win.minimumSize()) | ||||
|             win.setGeometry(rect) | ||||
|             win.move(pos) | ||||
|  | ||||
|   | ||||
| @@ -1,110 +0,0 @@ | ||||
| import os.path | ||||
| import json | ||||
| import urllib.request | ||||
| import webbrowser | ||||
| import random | ||||
|  | ||||
| from ..Qt import QtGui, QtCore, QtWidgets | ||||
| from .._py.pokemon import Ui_Dialog | ||||
|  | ||||
| random.seed() | ||||
|  | ||||
|  | ||||
| class QPokemon(QtWidgets.QDialog, Ui_Dialog): | ||||
|     def __init__(self, number=None, parent=None): | ||||
|         super().__init__(parent=parent) | ||||
|         self.setupUi(self) | ||||
|         self._js = json.load(open(os.path.join(path_to_module, 'utils', 'pokemon.json'), 'r'), encoding='UTF-8') | ||||
|         self._id = 0 | ||||
|  | ||||
|         if number is not None and number in range(1, len(self._js)+1): | ||||
|             poke_nr = f'{number:03d}' | ||||
|             self._id = number | ||||
|         else: | ||||
|             poke_nr = f'{random.randint(1, len(self._js)):03d}' | ||||
|             self._id = int(poke_nr) | ||||
|  | ||||
|         self._pokemon = None | ||||
|         self.show_pokemon(poke_nr) | ||||
|         self.label_15.linkActivated.connect(lambda x: webbrowser.open(x)) | ||||
|  | ||||
|         self.buttonBox.clicked.connect(self.randomize) | ||||
|         self.next_button.clicked.connect(self.show_next) | ||||
|         self.prev_button.clicked.connect(self.show_prev) | ||||
|  | ||||
|     def show_pokemon(self, nr): | ||||
|         self._pokemon = self._js[nr] | ||||
|         self.setWindowTitle('Pokémon: ' + self._pokemon['Deutsch']) | ||||
|  | ||||
|         for i in range(self.tabWidget.count(), -1, -1): | ||||
|             print('i', self.tabWidget.count(), i) | ||||
|             try: | ||||
|                 self.tabWidget.widget(i).deleteLater() | ||||
|             except AttributeError: | ||||
|                 pass | ||||
|  | ||||
|         for n, img in self._pokemon['Bilder']: | ||||
|             w = QtWidgets.QWidget() | ||||
|             vl = QtWidgets.QVBoxLayout() | ||||
|             l = QtWidgets.QLabel(self) | ||||
|             l.setAlignment(QtCore.Qt.AlignHCenter) | ||||
|             pixmap = QtGui.QPixmap() | ||||
|  | ||||
|             try: | ||||
|                 pixmap.loadFromData(urllib.request.urlopen(img, timeout=0.5).read()) | ||||
|             except IOError: | ||||
|                 l.setText(n) | ||||
|             else: | ||||
|                 sc_pixmap = pixmap.scaled(256, 256, QtCore.Qt.KeepAspectRatio) | ||||
|                 l.setPixmap(sc_pixmap) | ||||
|  | ||||
|             vl.addWidget(l) | ||||
|             w.setLayout(vl) | ||||
|             self.tabWidget.addTab(w, n) | ||||
|  | ||||
|         if len(self._pokemon['Bilder']) <= 1: | ||||
|             self.tabWidget.tabBar().setVisible(False) | ||||
|         else: | ||||
|             self.tabWidget.tabBar().setVisible(True) | ||||
|         self.tabWidget.adjustSize() | ||||
|  | ||||
|         self.name.clear() | ||||
|         keys = ['National-Dex', 'Kategorie', 'Typ', 'Größe', 'Gewicht', 'Farbe', 'Link'] | ||||
|         label_list = [self.pokedex_nr, self.category, self.poketype, self.weight, self.height, self.color, self.info] | ||||
|         for (k, label) in zip(keys, label_list): | ||||
|             v = self._pokemon[k] | ||||
|             if isinstance(v, list): | ||||
|                 v = os.path.join('', *v) | ||||
|  | ||||
|             if k == 'Link': | ||||
|                 v = '<a href={}>{}</a>'.format(v, v) | ||||
|  | ||||
|             label.setText(v) | ||||
|  | ||||
|         for k in ['Deutsch', 'Japanisch', 'Englisch', 'Französisch']: | ||||
|             v = self._pokemon[k] | ||||
|             self.name.addItem(k + ': ' + v) | ||||
|  | ||||
|         self.adjustSize() | ||||
|  | ||||
|     def randomize(self, idd): | ||||
|         if idd.text() == 'Retry': | ||||
|             new_number = f'{random.randint(1, len(self._js)):03d}' | ||||
|             self._id = int(new_number) | ||||
|             self.show_pokemon(new_number) | ||||
|         else: | ||||
|             self.close() | ||||
|  | ||||
|     def show_next(self): | ||||
|         new_number = self._id + 1 | ||||
|         if new_number > len(self._js): | ||||
|             new_number = 1 | ||||
|         self._id = new_number | ||||
|         self.show_pokemon(f'{new_number:03d}') | ||||
|  | ||||
|     def show_prev(self): | ||||
|         new_number = self._id - 1 | ||||
|         if new_number == 0: | ||||
|             new_number = len(self._js) | ||||
|         self._id = new_number | ||||
|         self.show_pokemon(f'{new_number:03d}') | ||||
| @@ -441,7 +441,7 @@ class UpperManagement(QtCore.QObject): | ||||
|         # all-encompassing error catch | ||||
|         try: | ||||
|             for model_id, model_p in parameter.items(): | ||||
|                 m = Model(model_p['func']) | ||||
|                 m = model_p['func'] | ||||
|                 models[model_id] = m | ||||
|  | ||||
|                 m_complex = model_p['complex'] | ||||
| @@ -450,13 +450,16 @@ class UpperManagement(QtCore.QObject): | ||||
|                 # iterate over order of set id in active order and access parameter inside loop | ||||
|                 # instead of directly looping | ||||
|                 try: | ||||
|                     list_ids = list(model_p['parameter'].keys()) | ||||
|                     list_ids = list(model_p['data_parameter'].keys()) | ||||
|                     set_order = [self.active_id.index(i) for i in list_ids] | ||||
|                 except ValueError as e: | ||||
|                     raise Exception('Getting order failed') from e | ||||
|  | ||||
|                 for pos in set_order: | ||||
|                     set_id = list_ids[pos] | ||||
|  | ||||
|                     data_i = self.data[set_id] | ||||
|                     set_params = model_p['data_parameter'][set_id] | ||||
|                     try: | ||||
|                         data_i = self.data[set_id] | ||||
|                     except KeyError as e: | ||||
| @@ -488,7 +491,7 @@ class UpperManagement(QtCore.QObject): | ||||
|                         inside = np.where((_x >= x_lim[0]) & (_x <= x_lim[1])) | ||||
|                     else: | ||||
|                         inside = np.where((_x >= fit_limits[0]) & (_x <= fit_limits[1])) | ||||
|                          | ||||
|  | ||||
|                     try: | ||||
|                         if isinstance(we, str): | ||||
|                             d = fit_d.Data(_x[inside], _y[inside], we=we, idx=set_id) | ||||
| @@ -499,18 +502,12 @@ class UpperManagement(QtCore.QObject): | ||||
|  | ||||
|                     d.set_model(m) | ||||
|                     try: | ||||
|                         d.set_parameter(set_params[0], var=model_p['var'], | ||||
|                                         lb=model_p['lb'], ub=model_p['ub'], | ||||
|                                         fun_kwargs=set_params[1]) | ||||
|                         d.set_parameter(set_params[0], fun_kwargs=set_params[1]) | ||||
|                     except Exception as e: | ||||
|                         raise Exception('Setting parameter failed') from e | ||||
|  | ||||
|                     self.fitter.add_data(d) | ||||
|  | ||||
|                 model_globs = model_p['glob'] | ||||
|                 if model_globs: | ||||
|                     m.set_global_parameter(**model_p['glob']) | ||||
|  | ||||
|             for links_i in links: | ||||
|                 self.fitter.set_link_parameter((models[links_i[0]], links_i[1]), | ||||
|                                                (models[links_i[2]], links_i[3])) | ||||
| @@ -1170,7 +1167,6 @@ class UpperManagement(QtCore.QObject): | ||||
|  | ||||
|     @QtCore.pyqtSlot(dict) | ||||
|     def calc_relaxation(self, opts: dict): | ||||
|  | ||||
|         params = opts['pts'] | ||||
|         if len(params) == 4: | ||||
|             if params[3]: | ||||
|   | ||||
| @@ -1,7 +1,9 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from .model import Model | ||||
| from .parameter import Parameters | ||||
| from .parameter import Parameters, Parameter | ||||
|  | ||||
|  | ||||
| class Data(object): | ||||
| @@ -16,7 +18,7 @@ class Data(object): | ||||
|         self.model = None | ||||
|         self.minimizer = None | ||||
|         self.parameter = Parameters() | ||||
|         self.para_keys = None | ||||
|         self.para_keys: list = [] | ||||
|         self.fun_kwargs = {} | ||||
|  | ||||
|     def __len__(self): | ||||
| @@ -68,12 +70,19 @@ class Data(object): | ||||
|     def get_model(self): | ||||
|         return self.model | ||||
|  | ||||
|     def set_parameter(self, parameter, var=None, ub=None, lb=None, | ||||
|                       default_bounds=False, fun_kwargs=None): | ||||
|     def set_parameter(self, | ||||
|                       values: list[float | Parameter], | ||||
|                       *, | ||||
|                       var: list[bool] = None, | ||||
|                       ub: list[float] = None, | ||||
|                       lb: list[float] = None, | ||||
|                       default_bounds: bool = False, | ||||
|                       fun_kwargs: dict = None | ||||
|                       ): | ||||
|         """ | ||||
|         Creates parameter for this data. | ||||
|         If no Model is available, it falls back to the model | ||||
|         :param parameter: list of parameters | ||||
|         :param values: list of parameters | ||||
|         :param var: list of boolean or boolean; False fixes parameter at given list index. | ||||
|                     Single value is broadcast to all parameter | ||||
|         :param ub: list of upper boundaries or float; Single value is broadcast to all parameter. | ||||
| @@ -87,23 +96,46 @@ class Data(object): | ||||
|         model = self.model | ||||
|         if model is None: | ||||
|             # Data has no unique | ||||
|             if self.minimizer is None: | ||||
|                 model = None | ||||
|             else: | ||||
|             if self.minimizer is not None: | ||||
|                 model = self.minimizer.fit_model | ||||
|                 self.fun_kwargs.update(model.fun_kwargs) | ||||
|  | ||||
|         if model is None: | ||||
|             raise ValueError('No model found, please set model before parameters') | ||||
|  | ||||
|         if default_bounds: | ||||
|         if len(values) != len(model.params): | ||||
|             raise ValueError('Number of given parameter does not match number of model parameters') | ||||
|  | ||||
|         is_parameter = [isinstance(v, Parameter) for v in values] | ||||
|         if all(is_parameter): | ||||
|             for p_i in values: | ||||
|                 key = f"p{next(Parameters.parameter_counter)}" | ||||
|                 self.parameter.add_parameter(key, p_i) | ||||
|         elif any(is_parameter): | ||||
|             raise ValueError('list of parameter are not all float of Parameter') | ||||
|  | ||||
|         else: | ||||
|             if var is None: | ||||
|                 var = [True] * len(values) | ||||
|  | ||||
|             if lb is None: | ||||
|                 lb = model.lb | ||||
|                 if default_bounds: | ||||
|                     lb = model.lb | ||||
|                 else: | ||||
|                     lb = [None] * len(values) | ||||
|  | ||||
|             if ub is None: | ||||
|                 ub = model.ub | ||||
|                 if default_bounds: | ||||
|                     ub = model.ub | ||||
|                 else: | ||||
|                     ub = [None] * len(values) | ||||
|  | ||||
|         self.para_keys = self.parameter.add_parameter(parameter, var=var, lb=lb, ub=ub) | ||||
|             arg_names = ['name', 'value', 'var', 'lb', 'ub'] | ||||
|             for parameter_arg in zip(model.params, values, var, lb, ub): | ||||
|                 self.parameter.add(**{arg_name: arg_value for arg_name, arg_value in zip(arg_names, parameter_arg)}) | ||||
|  | ||||
|         self.para_keys = list(self.parameter.keys()) | ||||
|  | ||||
|         self.fun_kwargs.update(model.fun_kwargs) | ||||
|         if fun_kwargs is not None: | ||||
|             self.fun_kwargs.update(fun_kwargs) | ||||
|  | ||||
| @@ -123,6 +155,18 @@ class Data(object): | ||||
|         else: | ||||
|             return [p.value for p in self.minimizer.parameters[self.parameter]] | ||||
|  | ||||
|     def replace_parameter(self, key: str, parameter: Parameter) -> None: | ||||
|         tobereplaced = None | ||||
|         for k, v in self.parameter.items(): | ||||
|             if v.name == parameter.name: | ||||
|                 tobereplaced = k | ||||
|                 break | ||||
|  | ||||
|         if tobereplaced is None: | ||||
|             raise KeyError(f'Global parameter {key} not found in list of parameters') | ||||
|         self.para_keys[self.para_keys.index(tobereplaced)] = key | ||||
|         self.parameter.replace_parameter(tobereplaced, key, parameter) | ||||
|  | ||||
|     def cost(self, p): | ||||
|         """ | ||||
|         Cost function :math:`y-f(p, x)` | ||||
|   | ||||
| @@ -1,3 +1,5 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import warnings | ||||
| from itertools import product | ||||
|  | ||||
| @@ -21,13 +23,70 @@ class FitAbortException(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| # COST FUNCTIONS: f(x) - y (least_square, minimize), and f(x) (ODR) | ||||
| def _cost_scipy_glob(p: list[float], data: list[Data], varpars: list[str], used_pars: list[list[str]]): | ||||
|     # replace values | ||||
|     for keys, values in zip(varpars, p): | ||||
|         for data_i in data: | ||||
|             if keys in data_i.parameter.keys(): | ||||
|                 # TODO move this to scaled_value setter | ||||
|                 data_i.parameter[keys].scaled_value = values | ||||
|                 data_i.parameter[keys].namespace[keys] = data_i.parameter[keys].value | ||||
|     r = [] | ||||
|     # unpack parameter and calculate y values and concatenate all | ||||
|     for values, p_idx in zip(data, used_pars): | ||||
|         actual_parameters = [values.parameter[keys].value for keys in p_idx] | ||||
|         r = np.r_[r, values.cost(actual_parameters)] | ||||
|  | ||||
|     return r | ||||
|  | ||||
|  | ||||
| def _cost_scipy(p, data, varpars, used_pars): | ||||
|     for keys, values in zip(varpars, p): | ||||
|         data.parameter[keys].scaled_value = values | ||||
|         data.parameter[keys].namespace[keys] = data.parameter[keys].value | ||||
|  | ||||
|     actual_parameters = [data.parameter[keys].value for keys in used_pars] | ||||
|     return data.cost(actual_parameters) | ||||
|  | ||||
|  | ||||
| def _cost_odr(p: list[float], data: Data, varpars: list[str], used_pars: list[str], fitmode: int=0): | ||||
|     for keys, values in zip(varpars, p): | ||||
|         data.parameter[keys].scaled_value = values | ||||
|         data.parameter[keys].namespace[keys] = data.parameter[keys].value | ||||
|  | ||||
|     actual_parameters = [data.parameter[keys].value for keys in used_pars] | ||||
|  | ||||
|     return data.func(actual_parameters, data.x) | ||||
|  | ||||
|  | ||||
| def _cost_odr_glob(p: list[float], data: list[Data], var_pars: list[str], used_pars: list[str]): | ||||
|     # replace values | ||||
|     for data_i in data: | ||||
|         _update_parameter(data_i, var_pars, p) | ||||
|  | ||||
|     r = [] | ||||
|     # unpack parameter and calculate y values and concatenate all | ||||
|     for values, p_idx in zip(data, used_pars): | ||||
|         actual_parameters = [values.parameter[keys].value for keys in p_idx] | ||||
|         r = np.r_[r, values.func(actual_parameters, values.x)] | ||||
|  | ||||
|     return r | ||||
|  | ||||
|  | ||||
| def _update_parameter(data: Data, varied_keys: list[str], parameter: list[float]): | ||||
|     for keys, values in zip(varied_keys, parameter): | ||||
|         if keys in data.parameter.keys(): | ||||
|             data.parameter[keys].scaled_value = values | ||||
|             data.parameter[keys].namespace[keys] = data.parameter[keys].value | ||||
|  | ||||
|  | ||||
| class FitRoutine(object): | ||||
|     def __init__(self, mode='lsq'): | ||||
|         self.fitmethod = mode | ||||
|         self.data = [] | ||||
|         self.fit_model = None | ||||
|         self._no_own_model = [] | ||||
|         self.parameter = Parameters() | ||||
|         self.result = [] | ||||
|         self.linked = [] | ||||
|         self._abort = False | ||||
| @@ -81,29 +140,27 @@ class FitRoutine(object): | ||||
|  | ||||
|         return self.fit_model | ||||
|  | ||||
|     def set_link_parameter(self, parameter: tuple, replacement: tuple): | ||||
|     def set_link_parameter(self, dismissed_param: tuple[Model | Data, str], replacement: tuple[Model, str]): | ||||
|         if isinstance(replacement[0], Model): | ||||
|             if replacement[1] not in replacement[0].global_parameter: | ||||
|                 raise KeyError(f'Parameter at pos {replacement[1]} of ' | ||||
|                                f'model {str(replacement[0])} is not global') | ||||
|             if replacement[1] not in replacement[0].parameter: | ||||
|                 raise KeyError(f'Parameter {replacement[1]} of ' | ||||
|                                f'model {replacement[0]} is not global') | ||||
|  | ||||
|         if isinstance(parameter[0], Model): | ||||
|             warnings.warn(f'Replaced parameter at pos {parameter[1]} in {str(parameter[0])} ' | ||||
|         if isinstance(dismissed_param[0], Model): | ||||
|             warnings.warn(f'Replaced parameter {dismissed_param[1]} in {dismissed_param[0]} ' | ||||
|                           f'becomes global with linkage.') | ||||
|  | ||||
|         self.linked.append((*parameter, *replacement)) | ||||
|         self.linked.append((*dismissed_param, *replacement)) | ||||
|  | ||||
|     def prepare_links(self): | ||||
|         self._no_own_model = [] | ||||
|         self.parameter = Parameters() | ||||
|         _found_models = {} | ||||
|         linked_sender = {} | ||||
|  | ||||
|         for v in self.data: | ||||
|             linked_sender[v] = set() | ||||
|             self.parameter.update(v.parameter.copy()) | ||||
|  | ||||
|             # set temporaray model | ||||
|             # set temporary model | ||||
|             if v.model is None: | ||||
|                 v.model = self.fit_model | ||||
|                 self._no_own_model.append(v) | ||||
| @@ -111,8 +168,6 @@ class FitRoutine(object): | ||||
|             # register model | ||||
|             if v.model not in _found_models: | ||||
|                 _found_models[v.model] = [] | ||||
|                 m_param = v.model.parameter.copy() | ||||
|                 self.parameter.update(m_param) | ||||
|  | ||||
|             _found_models[v.model].append(v) | ||||
|  | ||||
| @@ -120,24 +175,21 @@ class FitRoutine(object): | ||||
|                 linked_sender[v.model] = set() | ||||
|  | ||||
|         linked_parameter = {} | ||||
|         for par, par_parm, repl, repl_par in self.linked: | ||||
|             if isinstance(par, Data): | ||||
|                 if isinstance(repl, Data): | ||||
|                     linked_parameter[par.para_keys[par_parm]] = repl.para_keys[repl_par] | ||||
|                 else: | ||||
|                     linked_parameter[par.para_keys[par_parm]] = repl.global_parameter[repl_par] | ||||
|         for dismiss_model, dismiss_param, replace_model, replace_param in self.linked: | ||||
|             linked_sender[replace_model].add(dismiss_model) | ||||
|             linked_sender[replace_model].add(replace_model) | ||||
|  | ||||
|             replace_key = replace_model.parameter.get_key(replace_param) | ||||
|             dismiss_key = dismiss_model.parameter.get_key(dismiss_param) | ||||
|  | ||||
|             if isinstance(replace_model, Data): | ||||
|                 linked_parameter[dismiss_key] = replace_key | ||||
|             else: | ||||
|                 if isinstance(repl, Data): | ||||
|                     par.global_parameter[par_parm] = repl.para_keys[repl_par] | ||||
|                 else: | ||||
|                     par.global_parameter[par_parm] = repl.global_parameter[repl_par] | ||||
|  | ||||
|             linked_sender[repl].add(par) | ||||
|             linked_sender[par].add(repl) | ||||
|                 p = dismiss_model.set_global_parameter(dismiss_param, replace_key) | ||||
|                 p._expr_disp = replace_param | ||||
|  | ||||
|         for mm, m_data in _found_models.items(): | ||||
|             if mm.global_parameter: | ||||
|             if mm.parameter: | ||||
|                 for dd in m_data: | ||||
|                     linked_sender[mm].add(dd) | ||||
|                     linked_sender[dd].add(mm) | ||||
| @@ -169,15 +221,13 @@ class FitRoutine(object): | ||||
|         logger.info('Fit aborted by user') | ||||
|         self._abort = True | ||||
|  | ||||
|     def run(self, mode: str=None): | ||||
|     def run(self, mode: str = None): | ||||
|         self._abort = False | ||||
|         self.parameter = Parameters() | ||||
|  | ||||
|         if mode is None: | ||||
|             mode = self.fitmethod | ||||
|  | ||||
|         fit_groups, linked_parameter = self.prepare_links() | ||||
|  | ||||
|         for data_groups in fit_groups: | ||||
|             if len(data_groups) == 1 and not self.linked: | ||||
|                 data = data_groups[0] | ||||
| @@ -208,8 +258,21 @@ class FitRoutine(object): | ||||
|  | ||||
|         self.unprep_run() | ||||
|  | ||||
|         for r in self.result: | ||||
|             r.pprint() | ||||
|  | ||||
|         return self.result | ||||
|  | ||||
|     def make_preview(self, x: np.ndarray) -> list[np.ndarray]: | ||||
|         y_pred = [] | ||||
|         fit_groups, linked_parameter = self.prepare_links() | ||||
|         for data_groups in fit_groups: | ||||
|             data = data_groups[0] | ||||
|             actual_parameters = [p.value for p in data.parameter.values()] | ||||
|             y_pred.append(data.func(actual_parameters, x)) | ||||
|  | ||||
|         return y_pred | ||||
|  | ||||
|     def _prep_data(self, data): | ||||
|         if data.get_model() is None: | ||||
|             data._model = self.fit_model | ||||
| @@ -237,22 +300,16 @@ class FitRoutine(object): | ||||
|         var = [] | ||||
|         data_pars = [] | ||||
|  | ||||
|         # loopyloop over data that belong to one fit (linked or global) | ||||
|         # loopy-loop over data that belong to one fit (linked or global) | ||||
|         for data in data_group: | ||||
|             # is parameter replaced by global parameter? | ||||
|             for k, v in data.model.parameter.items(): | ||||
|                 data.replace_parameter(k, v) | ||||
|  | ||||
|             actual_pars = [] | ||||
|             for i, (p_k, v_k) in enumerate(data.parameter.items()): | ||||
|             for i, p_k in enumerate(data.para_keys): | ||||
|                 p_k_used = p_k | ||||
|                 v_k_used = v_k | ||||
|  | ||||
|                 # is parameter replaced by global parameter? | ||||
|                 if i in data.model.global_parameter: | ||||
|                     p_k_used = data.model.global_parameter[i] | ||||
|                     v_k_used = self.parameter[p_k_used] | ||||
|  | ||||
|                 # links trump global parameter | ||||
|                 if p_k_used in linked: | ||||
|                     p_k_used = linked[p_k_used] | ||||
|                     v_k_used = self.parameter[p_k_used] | ||||
|                 v_k_used = data.parameter[p_k] | ||||
|  | ||||
|                 actual_pars.append(p_k_used) | ||||
|                 # parameter is variable and was not found before as shared parameter | ||||
| @@ -271,48 +328,7 @@ class FitRoutine(object): | ||||
|             d._model = None | ||||
|  | ||||
|         self._no_own_model = [] | ||||
|  | ||||
|     # COST FUNCTIONS: f(x) - y (least_square, minimize), and f(x) (ODR) | ||||
|     def __cost_scipy(self, p, data, varpars, used_pars): | ||||
|         for keys, values in zip(varpars, p): | ||||
|             self.parameter[keys].scaled_value = values | ||||
|  | ||||
|         actual_parameters = [self.parameter[keys].value for keys in used_pars] | ||||
|         return data.cost(actual_parameters) | ||||
|  | ||||
|     def __cost_odr(self, p, data, varpars, used_pars): | ||||
|         for keys, values in zip(varpars, p): | ||||
|             self.parameter[keys].scaled_value = values | ||||
|  | ||||
|         actual_parameters = [self.parameter[keys].value for keys in used_pars] | ||||
|  | ||||
|         return data.func(actual_parameters, data.x) | ||||
|  | ||||
|     def __cost_scipy_glob(self, p, data, varpars, used_pars): | ||||
|         # replace values | ||||
|         for keys, values in zip(varpars, p): | ||||
|             self.parameter[keys].scaled_value = values | ||||
|  | ||||
|         r = [] | ||||
|         # unpack parameter and calculate y values and concatenate all | ||||
|         for values, p_idx in zip(data, used_pars): | ||||
|             actual_parameters = [self.parameter[keys].value for keys in p_idx] | ||||
|             r = np.r_[r, values.cost(actual_parameters)] | ||||
|  | ||||
|         return r | ||||
|  | ||||
|     def __cost_odr_glob(self, p, data, varpars, used_pars): | ||||
|         # replace values | ||||
|         for keys, values in zip(varpars, p): | ||||
|             self.parameter[keys].scaled_value = values | ||||
|  | ||||
|         r = [] | ||||
|         # unpack parameter and calculate y values and concatenate all | ||||
|         for values, p_idx in zip(data, used_pars): | ||||
|             actual_parameters = [self.parameter[keys].value for keys in p_idx] | ||||
|             r = np.r_[r, values.func(actual_parameters, values.x)] | ||||
|  | ||||
|         return r | ||||
|         Parameters.reset() | ||||
|  | ||||
|     def _least_squares_single(self, data, p0, lb, ub, var): | ||||
|         self.step = 0 | ||||
| @@ -322,7 +338,7 @@ class FitRoutine(object): | ||||
|             if self._abort: | ||||
|                 raise FitAbortException(f'Fit aborted by user') | ||||
|  | ||||
|             return self.__cost_scipy(p, data, var, data.para_keys) | ||||
|             return _cost_scipy(p, data, var, data.para_keys) | ||||
|  | ||||
|         with np.errstate(all='ignore'): | ||||
|             res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0)) | ||||
| @@ -336,7 +352,7 @@ class FitRoutine(object): | ||||
|             self.step += 1 | ||||
|             if self._abort: | ||||
|                 raise FitAbortException(f'Fit aborted by user') | ||||
|             return self.__cost_scipy_glob(p, data, var, data_pars) | ||||
|             return _cost_scipy_glob(p, data, var, data_pars) | ||||
|  | ||||
|         with np.errstate(all='ignore'): | ||||
|             res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0)) | ||||
| @@ -351,7 +367,7 @@ class FitRoutine(object): | ||||
|             self.step += 1 | ||||
|             if self._abort: | ||||
|                 raise FitAbortException(f'Fit aborted by user') | ||||
|             return (self.__cost_scipy(p, data, var, data.para_keys)**2).sum() | ||||
|             return (_cost_scipy(p, data, var, data.para_keys) ** 2).sum() | ||||
|  | ||||
|         with np.errstate(all='ignore'): | ||||
|             res = optimize.minimize(cost, p0, bounds=[(b1, b2) for (b1, b2) in zip(lb, ub)], | ||||
| @@ -364,7 +380,7 @@ class FitRoutine(object): | ||||
|             self.step += 1 | ||||
|             if self._abort: | ||||
|                 raise FitAbortException(f'Fit aborted by user') | ||||
|             return (self.__cost_scipy_glob(p, data, var, data_pars)**2).sum() | ||||
|             return (_cost_scipy_glob(p, data, var, data_pars) ** 2).sum() | ||||
|  | ||||
|         with np.errstate(all='ignore'): | ||||
|             res = optimize.minimize(cost, p0, bounds=[(b1, b2) for (b1, b2) in zip(lb, ub)], | ||||
| @@ -380,13 +396,18 @@ class FitRoutine(object): | ||||
|             self.step += 1 | ||||
|             if self._abort: | ||||
|                 raise FitAbortException(f'Fit aborted by user') | ||||
|             return self.__cost_odr(p, data, var_pars, data.para_keys) | ||||
|             return _cost_odr(p, data, var_pars, data.para_keys) | ||||
|  | ||||
|         odr_model = odr.Model(func) | ||||
|  | ||||
|         corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0) | ||||
|  | ||||
|         self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)), | ||||
|                           err=res.sd_beta, corr=corr, partial_corr=partial_corr) | ||||
|  | ||||
|     def _odr_fit(self, odr_data, odr_model, p0): | ||||
|         o = odr.ODR(odr_data, odr_model, beta0=p0) | ||||
|         res = o.run() | ||||
|  | ||||
|         corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var | ||||
|         try: | ||||
|             corr_inv = np.linalg.inv(corr) | ||||
| @@ -395,16 +416,14 @@ class FitRoutine(object): | ||||
|             partial_corr[np.diag_indices_from(partial_corr)] = 1. | ||||
|         except np.linalg.LinAlgError: | ||||
|             partial_corr = corr | ||||
|  | ||||
|         self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)), | ||||
|                           err=res.sd_beta, corr=corr, partial_corr=partial_corr) | ||||
|         return corr, partial_corr, res | ||||
|  | ||||
|     def _odr_global(self, data, p0, var, data_pars): | ||||
|         def func(p, _): | ||||
|             self.step += 1 | ||||
|             if self._abort: | ||||
|                 raise FitAbortException(f'Fit aborted by user') | ||||
|             return self.__cost_odr_glob(p, data, var, data_pars) | ||||
|             return _cost_odr_glob(p, data, var, data_pars) | ||||
|  | ||||
|         x = [] | ||||
|         y = [] | ||||
| @@ -415,17 +434,7 @@ class FitRoutine(object): | ||||
|         odr_data = odr.Data(x, y) | ||||
|         odr_model = odr.Model(func) | ||||
|  | ||||
|         o = odr.ODR(odr_data, odr_model, beta0=p0, ifixb=var) | ||||
|         res = o.run() | ||||
|  | ||||
|         corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var | ||||
|         try: | ||||
|             corr_inv = np.linalg.inv(corr) | ||||
|             corr_inv_diag = np.diag(np.sqrt(1 / np.diag(corr_inv))) | ||||
|             partial_corr = -1. * np.dot(np.dot(corr_inv_diag, corr_inv), corr_inv_diag)  # Partial correlation matrix | ||||
|             partial_corr[np.diag_indices_from(partial_corr)] = 1. | ||||
|         except np.linalg.LinAlgError: | ||||
|             partial_corr = corr | ||||
|         corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0) | ||||
|  | ||||
|         for v, var_pars_k in zip(data, data_pars): | ||||
|             self.make_results(v, res.beta, var, var_pars_k, (sum(len(d) for d in data), len(p0)), | ||||
| @@ -439,15 +448,17 @@ class FitRoutine(object): | ||||
|  | ||||
|         # update parameter values | ||||
|         for keys, p_value, err_value in zip(var_pars, p, err): | ||||
|             self.parameter[keys].scaled_value = p_value | ||||
|             self.parameter[keys].scaled_error = err_value | ||||
|             if keys in data.parameter.keys(): | ||||
|                 data.parameter[keys].scaled_value = p_value | ||||
|                 data.parameter[keys].scaled_error = err_value | ||||
|                 data.parameter[keys].namespace[keys] = data.parameter[keys].value | ||||
|  | ||||
|         combinations = list(product(var_pars, var_pars)) | ||||
|         actual_parameters = [] | ||||
|         corr_idx = [] | ||||
|  | ||||
|         for i, p_i in enumerate(used_pars): | ||||
|             actual_parameters.append(self.parameter[p_i]) | ||||
|             actual_parameters.append(data.parameter[p_i]) | ||||
|             for j, p_j in enumerate(used_pars): | ||||
|                 try: | ||||
|                     # find the position of the parameter combinations | ||||
| @@ -508,3 +519,4 @@ class FitRoutine(object): | ||||
|             partial_corr = corr | ||||
|  | ||||
|         return _err, corr, partial_corr | ||||
|  | ||||
|   | ||||
| @@ -6,7 +6,7 @@ from typing import Sized | ||||
| from numpy import inf | ||||
|  | ||||
| from ._meta import MultiModel | ||||
| from .parameter import Parameters | ||||
| from .parameter import Parameters, Parameter | ||||
|  | ||||
|  | ||||
| class Model(object): | ||||
| @@ -25,7 +25,6 @@ class Model(object): | ||||
|         self.ub = [i if i is not None else inf for i in self.ub] | ||||
|  | ||||
|         self.parameter = Parameters() | ||||
|         self.global_parameter = {} | ||||
|         self.is_complex = None | ||||
|         self._complex_part = False | ||||
|  | ||||
| @@ -80,23 +79,33 @@ class Model(object): | ||||
|             self.fun_kwargs = {k: v.default for k, v in inspect.signature(model.func).parameters.items() | ||||
|                                if v.default is not inspect.Parameter.empty} | ||||
|  | ||||
|     def set_global_parameter(self, idx, p, var=None, lb=None, ub=None, default_bounds=False): | ||||
|         if idx is None: | ||||
|             self.parameter = Parameters() | ||||
|             self.global_parameter = {} | ||||
|             return | ||||
|     def set_global_parameter(self, | ||||
|                              key: str | Parameter, | ||||
|                              value: float | str = None, | ||||
|                              *, | ||||
|                              var: bool = None, | ||||
|                              lb: float = None, | ||||
|                              ub: float = None, | ||||
|                              default_bounds: bool = False, | ||||
|                              ) -> Parameter: | ||||
|  | ||||
|         if default_bounds: | ||||
|             if lb is None: | ||||
|                 lb = [self.lb[i] for i in idx] | ||||
|             if ub is None: | ||||
|                 ub = [self.lb[i] for i in idx] | ||||
|         if isinstance(key, Parameter): | ||||
|             p = key | ||||
|             key = f'p{next(Parameters.parameter_counter)}' | ||||
|             self.parameter.add_parameter(key, p) | ||||
|  | ||||
|         gp = self.parameter.add_parameter(p, var=var, lb=lb, ub=ub) | ||||
|         for k, v in zip(idx, gp): | ||||
|             self.global_parameter[k] = v | ||||
|         else: | ||||
|             idx = [self.params.index(key)] | ||||
|             if default_bounds: | ||||
|                 if lb is None: | ||||
|                     lb = [self.lb[i] for i in idx] | ||||
|                 if ub is None: | ||||
|                     ub = [self.lb[i] for i in idx] | ||||
|  | ||||
|         return gp | ||||
|             p = self.parameter.add(key, value, var=var, lb=lb, ub=ub) | ||||
|             p.is_global = True | ||||
|  | ||||
|         return p | ||||
|  | ||||
|     @staticmethod | ||||
|     def _prep(param_len, val): | ||||
|   | ||||
| @@ -1,124 +1,170 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from numbers import Number | ||||
| import re | ||||
| from itertools import count | ||||
|  | ||||
| from io import StringIO | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| class Parameters(dict): | ||||
|     count = count() | ||||
|     parameter_counter = count() | ||||
|     # is one global namespace a good idea? | ||||
|     namespace: dict = {} | ||||
|  | ||||
|     def __str__(self): | ||||
|         return 'Parameters:\n' + '\n'.join([str(k)+': '+str(v) for k, v in self.items()]) | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self._mapping: dict = {} | ||||
|  | ||||
|     def __getitem__(self, item): | ||||
|         if isinstance(item, (list, tuple, np.ndarray)): | ||||
|             values = [] | ||||
|             for item_i in item: | ||||
|                 values.append(super().__getitem__(item_i)) | ||||
|             return values | ||||
|     def __str__(self) -> str: | ||||
|         return 'Parameters:\n' + '\n'.join([f'{k}: {v}' for k, v in self.items()]) | ||||
|  | ||||
|     def __getitem__(self, item) -> Parameter: | ||||
|         if item in self._mapping: | ||||
|             return super().__getitem__(self._mapping[item]) | ||||
|         else: | ||||
|             return super().__getitem__(item) | ||||
|  | ||||
|     def __setitem__(self, key, value): | ||||
|         self.add_parameter(key, value) | ||||
|  | ||||
|     def __contains__(self, item): | ||||
|         for v in self.values(): | ||||
|             if item == v.name: | ||||
|                 return True | ||||
|  | ||||
|         return False | ||||
|  | ||||
|     def add(self, | ||||
|             name: str, | ||||
|             value: str | float | int = None, | ||||
|             *, | ||||
|             var: bool = True, | ||||
|             lb: float = -np.inf, ub: float = np.inf) -> Parameter: | ||||
|  | ||||
|         par = Parameter(name=name, value=value, var=var, lb=lb, ub=ub) | ||||
|         key = f'p{next(Parameters.parameter_counter)}' | ||||
|  | ||||
|         self.add_parameter(key, par) | ||||
|  | ||||
|         return par | ||||
|  | ||||
|     def add_parameter(self, key: str, parameter: Parameter): | ||||
|         self._mapping[parameter.name] = key | ||||
|         super().__setitem__(key, parameter) | ||||
|  | ||||
|         parameter.eval_allowed = False | ||||
|         self.namespace[key] = parameter.value | ||||
|         parameter.namespace = self.namespace | ||||
|         parameter.eval_allowed = True | ||||
|  | ||||
|         self.update_namespace() | ||||
|  | ||||
|     def replace_parameter(self, key_out: str, key_in: str, parameter: Parameter): | ||||
|         self.add_parameter(key_in, parameter) | ||||
|         for k, v in self._mapping.items(): | ||||
|             if v == key_out: | ||||
|                 self._mapping[k] = key_in | ||||
|                 break | ||||
|  | ||||
|         if key_out in self.namespace: | ||||
|             del self.namespace[key_out] | ||||
|  | ||||
|     def fix(self): | ||||
|         for v in self.keys(): | ||||
|             v._value = v.value | ||||
|             v.namespace = {} | ||||
|  | ||||
|     @staticmethod | ||||
|     def _prep_bounds(val, p_len: int) -> list: | ||||
|         # helper function to ensure that bounds and variable are of parameter shape | ||||
|         if isinstance(val, (Number, bool)) or val is None: | ||||
|             return [val] * p_len | ||||
|     def reset(): | ||||
|         Parameters.namespace = {} | ||||
|  | ||||
|         elif len(val) == p_len: | ||||
|             return val | ||||
|  | ||||
|         elif len(val) == 1: | ||||
|             return [val[0]] * p_len | ||||
|  | ||||
|         else: | ||||
|             raise ValueError('Input {} has wrong dimensions'.format(val)) | ||||
|  | ||||
|     def add_parameter(self, param, var=None, lb=None, ub=None): | ||||
|         if isinstance(param, Number): | ||||
|             param = [param] | ||||
|  | ||||
|         p_len = len(param) | ||||
|  | ||||
|         # make list if only single value is given | ||||
|         var = self._prep_bounds(var, p_len) | ||||
|         lb = self._prep_bounds(lb, p_len) | ||||
|         ub = self._prep_bounds(ub, p_len) | ||||
|  | ||||
|         new_keys = [] | ||||
|         for i in range(p_len): | ||||
|             new_idx = next(self.count) | ||||
|             new_keys.append(new_idx) | ||||
|  | ||||
|             self[new_idx] = Parameter(param[i], var=var[i], lb=lb[i], ub=ub[i]) | ||||
|  | ||||
|         return new_keys | ||||
|  | ||||
|     def copy(self): | ||||
|         p = Parameters() | ||||
|     def get_key(self, name: str) -> str | None: | ||||
|         for k, v in self.items(): | ||||
|             p[k] = Parameter(v.value, var=v.var, lb=v.lb, ub=v.ub) | ||||
|             if name == v.name: | ||||
|                 return k | ||||
|  | ||||
|         if len(p) == 0: | ||||
|             return p | ||||
|  | ||||
|         max_k = max(p.keys()) | ||||
|         c = next(p.count) | ||||
|         while c < max_k: | ||||
|             c = next(p.count) | ||||
|  | ||||
|         return p | ||||
|         return | ||||
|  | ||||
|     def get_state(self): | ||||
|         return {k: v.get_state() for k, v in self.items()} | ||||
|  | ||||
|     def update_namespace(self): | ||||
|         for p in self.values(): | ||||
|             try: | ||||
|                 p.value | ||||
|             except NameError: | ||||
|                 expression = p._expr_disp | ||||
|                 for n, k in self._mapping.items(): | ||||
|                     expression, num_replaced = re.subn(re.escape(n), k, expression) | ||||
|  | ||||
|                 p._expr = expression | ||||
|  | ||||
|  | ||||
| class Parameter: | ||||
|     """ | ||||
|     Container for one parameter | ||||
|     """ | ||||
|     __slots__ = ['name', 'value', 'error', 'init_val', 'var', 'lb', 'ub', 'scale', 'function'] | ||||
|  | ||||
|     def __init__(self, value: float, var: bool = True, lb: float = -np.inf, ub: float = np.inf): | ||||
|         self.lb = lb if lb is not None else -np.inf | ||||
|         self.ub = ub if ub is not None else np.inf | ||||
|     # TODO Parameter should know its own key | ||||
|     def __init__(self, name: str, value: float | str, var: bool = True, lb: float = -np.inf, ub: float = np.inf): | ||||
|         self._value: float | None = None | ||||
|         self.var: bool = bool(var) if var is not None else True | ||||
|         self.error:  None | float = None if self.var is False else 0.0 | ||||
|         self.name: str = name | ||||
|         self.function: str = "" | ||||
|  | ||||
|         if self.lb <= value <= self.ub: | ||||
|             self.value = value | ||||
|         self.lb: None | float = lb if lb is not None else -np.inf | ||||
|         self.ub: float | None = ub if ub is not None else np.inf | ||||
|         self.namespace: dict = {} | ||||
|         self.eval_allowed: bool = True | ||||
|         self._expr: None | str = None | ||||
|         self._expr_disp: None | str = None | ||||
|         self.is_global = False | ||||
|  | ||||
|         if isinstance(value, str): | ||||
|             self._expr_disp = value | ||||
|             self._expr = value | ||||
|             self.var = False | ||||
|         else: | ||||
|             print(value, self.lb, self.ub) | ||||
|             raise ValueError('Value of parameter is outside bounds') | ||||
|             if self.lb <= value <= self.ub: | ||||
|                 self._value = value | ||||
|             else: | ||||
|                 raise ValueError('Value of parameter is outside bounds') | ||||
|  | ||||
|         self.init_val = value | ||||
|             self.init_val = value | ||||
|  | ||||
|         with np.errstate(divide='ignore'): | ||||
|             # throws RuntimeWarning for zeros | ||||
|             self.scale = 10**(np.floor(np.log10(np.abs(self.value)))) | ||||
|             with np.errstate(divide='ignore'): | ||||
|                 # throws RuntimeWarning for zeros | ||||
|                 self.scale = 10**(np.floor(np.log10(np.abs(self.value)))) | ||||
|  | ||||
|         if self.scale == 0: | ||||
|             self.scale = 1. | ||||
|             if self.scale == 0: | ||||
|                 self.scale = 1. | ||||
|  | ||||
|         self.var = bool(var) if var is not None else True | ||||
|         self.error = None if self.var is False else 0.0 | ||||
|         self.name = '' | ||||
|         self.function = '' | ||||
|  | ||||
|     def __str__(self): | ||||
|         start = '' | ||||
|     def __str__(self) -> str: | ||||
|         start = StringIO() | ||||
|         if self.name: | ||||
|             if self.function: | ||||
|                 start = f'{self.name} ({self.function}):  ' | ||||
|                 start.write(f"{self.name} ({self.function})") | ||||
|             else: | ||||
|                 start = self.name + ':  ' | ||||
|                 start.write(self.name) | ||||
|  | ||||
|         if self.is_global: | ||||
|             start.write("*") | ||||
|  | ||||
|         start.write(":  ") | ||||
|  | ||||
|         if self.var: | ||||
|             return start + f'{self.value:.4g} +/- {self.error:.4g}, init={self.init_val}' | ||||
|             start.write(f"{self.value:.4g} +/- {self.error:.4g}, init={self.init_val}") | ||||
|         else: | ||||
|             return start + f'{self.value:} (fixed)' | ||||
|             start.write(f"{self.value:.4g}") | ||||
|             if self._expr is None: | ||||
|                 start.write(" (fixed)") | ||||
|             else: | ||||
|                 start.write(f" (calc: {self._expr_disp})") | ||||
|  | ||||
|     def __add__(self, other: Parameter | float) -> float: | ||||
|         return start.getvalue() | ||||
|  | ||||
|     def __add__(self, other: Parameter | float | int) -> float: | ||||
|         if isinstance(other, (float, int)): | ||||
|             return self.value + other | ||||
|         elif isinstance(other, Parameter): | ||||
| @@ -128,30 +174,39 @@ class Parameter: | ||||
|         return self.__add__(other) | ||||
|  | ||||
|     @property | ||||
|     def scaled_value(self): | ||||
|     def scaled_value(self) -> float: | ||||
|         return self.value / self.scale | ||||
|  | ||||
|     @scaled_value.setter | ||||
|     def scaled_value(self, value): | ||||
|         self.value = value * self.scale | ||||
|     def scaled_value(self, value: float) -> None: | ||||
|         self._value = value * self.scale | ||||
|  | ||||
|     @property | ||||
|     def scaled_error(self): | ||||
|         if self.error is None: | ||||
|             return self.error | ||||
|         else: | ||||
|     def value(self) -> float | None: | ||||
|         if self._value is not None: | ||||
|             return self._value | ||||
|  | ||||
|         if self._expr is not None and self.eval_allowed: | ||||
|             return eval(self._expr, {}, self.namespace) | ||||
|  | ||||
|         return | ||||
|  | ||||
|     @property | ||||
|     def scaled_error(self) -> None | float: | ||||
|         if self.error is not None: | ||||
|             return self.error / self.scale | ||||
|  | ||||
|         return | ||||
|  | ||||
|     @scaled_error.setter | ||||
|     def scaled_error(self, value): | ||||
|     def scaled_error(self, value) -> None: | ||||
|         self.error = value * self.scale | ||||
|  | ||||
|     def get_state(self): | ||||
|  | ||||
|         return {slot: getattr(self, slot) for slot in  self.__slots__} | ||||
|     def get_state(self) -> dict: | ||||
|         return {slot: getattr(self, slot) for slot in self.__slots__} | ||||
|  | ||||
|     @staticmethod | ||||
|     def set_state(state: dict): | ||||
|     def set_state(state: dict) -> Parameter: | ||||
|         par = Parameter(state.pop('value')) | ||||
|         for k, v in state.items(): | ||||
|             setattr(par, k, v) | ||||
| @@ -159,9 +214,28 @@ class Parameter: | ||||
|         return par | ||||
|  | ||||
|     @property | ||||
|     def full_name(self): | ||||
|     def full_name(self) -> str: | ||||
|         name = self.name | ||||
|         if self.function: | ||||
|             name += ' (' + self.function + ')' | ||||
|             name += f" ({self.function})" | ||||
|  | ||||
|         return name | ||||
|  | ||||
|     def copy(self) -> Parameter: | ||||
|         if self._expr: | ||||
|             val = self._expr_disp | ||||
|         else: | ||||
|             val = self._value | ||||
|         para_copy = Parameter(name=self.name, value=val, var=self.var, lb=self.lb, ub=self.ub) | ||||
|         para_copy._expr = self._expr | ||||
|         para_copy.namespace = self.namespace | ||||
|         para_copy.is_global = self.is_global | ||||
|         para_copy.error = self.error | ||||
|         para_copy.function = self.function | ||||
|  | ||||
|         return para_copy | ||||
|  | ||||
|     def fix(self): | ||||
|         self._value = self.value | ||||
|         self.namespace = {} | ||||
|  | ||||
|   | ||||
| @@ -2,6 +2,7 @@ from __future__ import annotations | ||||
|  | ||||
| import re | ||||
| from collections import OrderedDict | ||||
| from io import StringIO | ||||
| from pathlib import Path | ||||
| from typing import Any | ||||
|  | ||||
| @@ -186,7 +187,7 @@ class FitResult(Points): | ||||
|                 nice_name = m.group(1) | ||||
|                 if func_number in split_funcs: | ||||
|                     nice_func = split_funcs[func_number] | ||||
|  | ||||
|             pvalue.fix() | ||||
|             pvalue.name = nice_name | ||||
|             pvalue.function = nice_func | ||||
|             parameter_dic[pname] = pvalue | ||||
| @@ -223,27 +224,30 @@ class FitResult(Points): | ||||
|         return self.nobs-self.nvar | ||||
|  | ||||
|     def pprint(self, statistics=True, correlations=True): | ||||
|         print('Fit result:') | ||||
|         print('  model :', self.name) | ||||
|         print('  #data :', self.nobs) | ||||
|         print('  #var  :', self.nvar) | ||||
|         print('\nParameter') | ||||
|         print(self.parameter_string()) | ||||
|         s = StringIO() | ||||
|         s.write('Fit result:\n') | ||||
|         s.write(f'  model : {self.name}\n') | ||||
|         s.write(f'  #data : {self.nobs}\n') | ||||
|         s.write(f'  #var  : {self.nvar}\n') | ||||
|         s.write('\nParameter\n') | ||||
|         s.write(self.parameter_string()) | ||||
|  | ||||
|         if statistics: | ||||
|             print('Statistics') | ||||
|             s.write('\nStatistics\n') | ||||
|             for k, v in self.statistics.items(): | ||||
|                 print(f'  {k} : {v:.4f}') | ||||
|                 s.write(f'  {k} : {v:.4f}\n') | ||||
|  | ||||
|         if correlations and self.correlation is not None: | ||||
|             print('\nCorrelation (partial corr.)') | ||||
|             print(self._correlation_string()) | ||||
|             print() | ||||
|             s.write('\nCorrelation (partial corr.)\n') | ||||
|             s.write(self._correlation_string()) | ||||
|             s.write('\n') | ||||
|  | ||||
|         print(s.getvalue()) | ||||
|  | ||||
|     def parameter_string(self): | ||||
|         ret_val = '' | ||||
|  | ||||
|         for pval in self.parameter.values(): | ||||
|         for pkey, pval in self.parameter.items(): | ||||
|             ret_val += convert(str(pval), old='tex', new='str') + '\n' | ||||
|  | ||||
|         if self.fun_kwargs: | ||||
| @@ -255,9 +259,7 @@ class FitResult(Points): | ||||
|     def _correlation_string(self): | ||||
|         ret_val = '' | ||||
|         for p_i, p_j, corr_ij, pcorr_ij in self.correlation_list(): | ||||
|             ret_val += '  {} / {} : {:.4f} ({:.4f})\n'.format(convert(p_i, old='tex', new='str'), | ||||
|                                                               convert(p_j, old='tex', new='str'), | ||||
|                                                               corr_ij, pcorr_ij) | ||||
|             ret_val += f"  {convert(p_i, old='tex', new='str')} / {convert(p_j, old='tex', new='str')} : {corr_ij:.4f} ({pcorr_ij:.4f})\n" | ||||
|         return ret_val | ||||
|  | ||||
|     def correlation_list(self, limit=0.1): | ||||
|   | ||||
| @@ -35,8 +35,8 @@ class Gaussian: | ||||
| class Lorentzian: | ||||
|     type = 'Spectrum' | ||||
|     name = 'Lorentzian' | ||||
|     equation = 'A (2/\pi)w/[4*(x-\mu)^{2} + w^{2}] + A_{0}' | ||||
|     params = ['A', '\mu', 'w', 'A_{0}'] | ||||
|     equation = r'A (2/\pi)w/[4*(x-\mu)^{2} + w^{2}] + A_{0}' | ||||
|     params = ['A', r'\mu', 'w', 'A_{0}'] | ||||
|     ext_params = None | ||||
|     bounds = [(0, None), (None, None), (0, None), (None, None)] | ||||
|  | ||||
| @@ -62,9 +62,9 @@ class Lorentzian: | ||||
| class PseudoVoigt: | ||||
|     type = 'Spectrum' | ||||
|     name = 'Pseudo Voigt' | ||||
|     equation = 'A [R*2/\pi*w/[4*(x-\mu)^{2} + w^{2}] + ' \ | ||||
|                '(1-R)*sqrt(4*ln(2)/pi)/w*exp(-4*ln(2)[(x-\mu)/w]^{2})] + A_{0}' | ||||
|     params = ['A', 'R', '\mu', 'w', 'A_{0}'] | ||||
|     equation = r'A [R*2/\pi*w/[4*(x-\mu)^{2} + w^{2}] + ' \ | ||||
|                r'(1-R)*sqrt(4*ln(2)/pi)/w*exp(-4*ln(2)[(x-\mu)/w]^{2})] + A_{0}' | ||||
|     params = ['A', 'R', r'\mu', 'w', 'A_{0}'] | ||||
|     ext_params = None | ||||
|     bounds = [(0, None), (0, 1), (None, None), (0, None)] | ||||
|  | ||||
|   | ||||
| @@ -7,7 +7,7 @@ | ||||
|     <x>0</x> | ||||
|     <y>0</y> | ||||
|     <width>365</width> | ||||
|     <height>78</height> | ||||
|     <height>66</height> | ||||
|    </rect> | ||||
|   </property> | ||||
|   <property name="sizePolicy"> | ||||
| @@ -62,7 +62,7 @@ | ||||
|      <item> | ||||
|       <widget class="LineEdit" name="parameter_line"> | ||||
|        <property name="sizePolicy"> | ||||
|         <sizepolicy hsizetype="Fixed" vsizetype="Fixed"> | ||||
|         <sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed"> | ||||
|          <horstretch>0</horstretch> | ||||
|          <verstretch>0</verstretch> | ||||
|         </sizepolicy> | ||||
| @@ -78,19 +78,6 @@ | ||||
|        </property> | ||||
|       </widget> | ||||
|      </item> | ||||
|      <item> | ||||
|       <spacer name="horizontalSpacer"> | ||||
|        <property name="orientation"> | ||||
|         <enum>Qt::Horizontal</enum> | ||||
|        </property> | ||||
|        <property name="sizeHint" stdset="0"> | ||||
|         <size> | ||||
|          <width>40</width> | ||||
|          <height>20</height> | ||||
|         </size> | ||||
|        </property> | ||||
|       </spacer> | ||||
|      </item> | ||||
|      <item> | ||||
|       <widget class="QCheckBox" name="fixed_check"> | ||||
|        <property name="text"> | ||||
| @@ -105,19 +92,6 @@ | ||||
|        </property> | ||||
|       </widget> | ||||
|      </item> | ||||
|      <item> | ||||
|       <widget class="QToolButton" name="toolButton"> | ||||
|        <property name="text"> | ||||
|         <string/> | ||||
|        </property> | ||||
|        <property name="popupMode"> | ||||
|         <enum>QToolButton::InstantPopup</enum> | ||||
|        </property> | ||||
|        <property name="arrowType"> | ||||
|         <enum>Qt::RightArrow</enum> | ||||
|        </property> | ||||
|       </widget> | ||||
|      </item> | ||||
|     </layout> | ||||
|    </item> | ||||
|    <item> | ||||
|   | ||||
		Reference in New Issue
	
	Block a user