diff --git a/src/gui_qt/_py/fitmodelwidget.py b/src/gui_qt/_py/fitmodelwidget.py index a41f14b..f183f36 100644 --- a/src/gui_qt/_py/fitmodelwidget.py +++ b/src/gui_qt/_py/fitmodelwidget.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- -# Form implementation generated from reading ui file 'resources/_ui/fitmodelwidget.ui' +# Form implementation generated from reading ui file 'src/resources/_ui/fitmodelwidget.ui' # -# Created by: PyQt5 UI code generator 5.12.3 +# Created by: PyQt5 UI code generator 5.15.9 # -# WARNING! All changes made in this file will be lost! +# WARNING: Any manual changes made to this file will be lost when pyuic5 is +# run again. Do not edit this file unless you know what you are doing. from PyQt5 import QtCore, QtGui, QtWidgets @@ -13,7 +14,7 @@ from PyQt5 import QtCore, QtGui, QtWidgets class Ui_FitParameter(object): def setupUi(self, FitParameter): FitParameter.setObjectName("FitParameter") - FitParameter.resize(365, 78) + FitParameter.resize(365, 66) sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.MinimumExpanding) sizePolicy.setHorizontalStretch(0) sizePolicy.setVerticalStretch(0) @@ -36,7 +37,7 @@ class Ui_FitParameter(object): self.parametername.setObjectName("parametername") self.horizontalLayout_2.addWidget(self.parametername) self.parameter_line = LineEdit(FitParameter) - sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed) + sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed) sizePolicy.setHorizontalStretch(0) sizePolicy.setVerticalStretch(0) sizePolicy.setHeightForWidth(self.parameter_line.sizePolicy().hasHeightForWidth()) @@ -44,20 +45,12 @@ class Ui_FitParameter(object): self.parameter_line.setText("") self.parameter_line.setObjectName("parameter_line") self.horizontalLayout_2.addWidget(self.parameter_line) - spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) - self.horizontalLayout_2.addItem(spacerItem) self.fixed_check = QtWidgets.QCheckBox(FitParameter) self.fixed_check.setObjectName("fixed_check") self.horizontalLayout_2.addWidget(self.fixed_check) self.global_checkbox = QtWidgets.QCheckBox(FitParameter) self.global_checkbox.setObjectName("global_checkbox") self.horizontalLayout_2.addWidget(self.global_checkbox) - self.toolButton = QtWidgets.QToolButton(FitParameter) - self.toolButton.setText("") - self.toolButton.setPopupMode(QtWidgets.QToolButton.InstantPopup) - self.toolButton.setArrowType(QtCore.Qt.RightArrow) - self.toolButton.setObjectName("toolButton") - self.horizontalLayout_2.addWidget(self.toolButton) self.verticalLayout.addLayout(self.horizontalLayout_2) self.frame = QtWidgets.QFrame(FitParameter) sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Maximum) diff --git a/src/gui_qt/data/container.py b/src/gui_qt/data/container.py index 913ad7e..106eee3 100644 --- a/src/gui_qt/data/container.py +++ b/src/gui_qt/data/container.py @@ -8,6 +8,7 @@ from pyqtgraph import mkPen from nmreval.data.points import Points from nmreval.data.signals import Signal +from nmreval.lib.logger import logger from nmreval.utils.text import convert from nmreval.data.bds import BDS from nmreval.data.dsc import DSC @@ -356,7 +357,7 @@ class ExperimentContainer(QtCore.QObject): elif mode in ['imag', 'all'] and self.plot_imag is not None: self.plot_imag.set_symbol(symbol=symbol, size=size, color=color) else: - print('Updating symbol failed for ' + str(self.id)) + logger.warning(f'Updating symbol failed for {self.id}') def setLine(self, *, width=None, style=None, color=None, mode='real'): if mode in ['real', 'all']: @@ -368,7 +369,7 @@ class ExperimentContainer(QtCore.QObject): elif mode in ['imag', 'all'] and self.plot_imag is not None: self.plot_imag.set_line(width=width, style=style, color=color) else: - print('Updating line failed for ' + str(self.id)) + logger.warning(f'Updating line failed for {self.id}') def update_property(self, key1: str, key2: str, value: Any): keykey = key2.split() diff --git a/src/gui_qt/data/signaledit/editsignalwidget.py b/src/gui_qt/data/signaledit/editsignalwidget.py index 980ea19..31c4aff 100644 --- a/src/gui_qt/data/signaledit/editsignalwidget.py +++ b/src/gui_qt/data/signaledit/editsignalwidget.py @@ -1,3 +1,4 @@ +from nmreval.lib.logger import logger from nmreval.math import apodization from nmreval.lib.importer import find_models from nmreval.utils.text import convert @@ -67,7 +68,7 @@ class EditSignalWidget(QtWidgets.QWidget, Ui_Form): self.do_something.emit(sender, (ph0, ph1, pvt)) else: - print('You should never reach this by accident.') + logger.warning(f'You should never reach this by accident, invalid sender {sender!r}') @QtCore.pyqtSlot(int, name='on_apodcombobox_currentIndexChanged') def change_apodization(self, index): diff --git a/src/gui_qt/fit/fit_forms.py b/src/gui_qt/fit/fit_forms.py index db7561d..67f27d2 100644 --- a/src/gui_qt/fit/fit_forms.py +++ b/src/gui_qt/fit/fit_forms.py @@ -19,19 +19,20 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter): super().__init__(parent) self.setupUi(self) - self.parametername.setText(label + ' ') + self.name = label + + self.parametername.setText(convert(label) + ' ') - validator = QtGui.QDoubleValidator() - self.parameter_line.setValidator(validator) self.parameter_line.setText('1') - self.parameter_line.setMaximumWidth(240) - self.lineEdit.setMaximumWidth(60) - self.lineEdit_2.setMaximumWidth(60) + self.parameter_line.setMaximumWidth(160) + self.lineEdit.setMaximumWidth(100) + self.lineEdit_2.setMaximumWidth(100) - self.label_3.setText(f'< {label} <') + self.label_3.setText(f'< {convert(label)} <') self.checkBox.stateChanged.connect(self.enableBounds) self.global_checkbox.stateChanged.connect(lambda: self.state_changed.emit()) + self.parameter_line.editingFinished.connect(self.update_parameter) self.parameter_line.values_requested.connect(lambda: self.value_requested.emit(self)) self.parameter_line.replace_single_values.connect(lambda: self.replace_single_value.emit(None)) self.parameter_line.editingFinished.connect(lambda: self.value_changed.emit(self.parameter_line.text())) @@ -40,18 +41,12 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter): if fixed: self.fixed_check.hide() - self.menu = QtWidgets.QMenu(self) - self.add_links() - - self.is_linked = None self.parameter_pos = None self.func_idx = None self._linetext = '1' - @property - def name(self): - return convert(self.parametername.text().strip(), old='html', new='str') + self.menu = QtWidgets.QMenu(self) def set_parameter_string(self, p: str): self.parameter_line.setText(p) @@ -71,38 +66,24 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter): def set_parameter(self, p: float | None, bds: tuple[float, float, bool] = None, fixed: bool = None, glob: bool = None): - if p is None: - # bad hack: linked parameter return (None, linked parameter) - # if p is None -> parameter is linked to argument given by bds - self.link_parameter(linkto=bds) - else: - ptext = f'{p:.4g}' + ptext = f'{p:.4g}' - self.set_parameter_string(ptext) + self.set_parameter_string(ptext) - if bds is not None: - self.set_bounds(*bds) + if bds is not None: + self.set_bounds(*bds) - if fixed is not None: - self.fixed_check.setCheckState(QtCore.Qt.Unchecked if fixed else QtCore.Qt.Checked) + if fixed is not None: + self.fixed_check.setCheckState(QtCore.Qt.Unchecked if fixed else QtCore.Qt.Checked) - if glob is not None: - self.global_checkbox.setCheckState(QtCore.Qt.Checked if glob else QtCore.Qt.Unchecked) + if glob is not None: + self.global_checkbox.setCheckState(QtCore.Qt.Checked if glob else QtCore.Qt.Unchecked) def get_parameter(self): - if self.is_linked: - try: - p = float(self._linetext) - except ValueError: - p = 1.0 - else: - try: - p = float(self.parameter_line.text().replace(',', '.')) - except ValueError: - _ = QtWidgets.QMessageBox().warning(self, 'Invalid value', - f'{self.parametername.text()} contains invalid values', - QtWidgets.QMessageBox.Cancel) - return None + try: + p = float(self.parameter_line.text().replace(',', '.')) + except ValueError: + p = self.parameter_line.text().replace(',', '.') if self.checkBox.isChecked(): try: @@ -119,75 +100,27 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter): bounds = (lb, rb) - return p, bounds, not self.fixed_check.isChecked(), self.global_checkbox.isChecked(), self.is_linked + return p, bounds, not self.fixed_check.isChecked(), self.global_checkbox.isChecked() @QtCore.pyqtSlot(bool) def set_fixed(self, state: bool): # self.global_checkbox.setVisible(not state) self.frame.setVisible(not state) - def add_links(self, parameter: dict = None): - if parameter is None: - parameter = {} - self.menu.clear() - - ac = QtWidgets.QAction('Link to...', self) - ac.triggered.connect(self.link_parameter) - self.menu.addAction(ac) - - for model_key, model_funcs in parameter.items(): - m = QtWidgets.QMenu('Model ' + model_key, self) - for func_name, func_params in model_funcs.items(): - m2 = QtWidgets.QMenu(func_name, m) - for p_name, idx in func_params: - ac = QtWidgets.QAction(p_name, m2) - ac.setData((model_key, *idx)) - ac.triggered.connect(self.link_parameter) - m2.addAction(ac) - m.addMenu(m2) - self.menu.addMenu(m) - - self.toolButton.setMenu(self.menu) - @QtCore.pyqtSlot() - def link_parameter(self, linkto=None): - if linkto is None: - action = self.sender() - else: - action = False - for m in self.menu.actions(): - if m.menu(): - for a in m.menu().actions(): - if a.data() == linkto: - action = a - break - if action: - break - - if (self.func_idx, self.parameter_pos) == action.data(): - return + def update_parameter(self): + new_value = self.parameter_line.text() + if not new_value: + self.parameter_line.setText('1') try: - new_text = f'Linked to {action.parentWidget().title()}.{action.text()}' - self._linetext = self.parameter_line.text() - self.parameter_line.setText(new_text) - self.parameter_line.setEnabled(False) - self.global_checkbox.hide() - self.global_checkbox.blockSignals(True) - self.global_checkbox.setCheckState(QtCore.Qt.Checked) - self.global_checkbox.blockSignals(False) - self.frame.hide() - self.is_linked = action.data() + float(new_value) + is_text = False + except ValueError: + is_text = True + self.global_checkbox.setCheckState(False) - except AttributeError: - self.parameter_line.setText(self._linetext) - self.parameter_line.setEnabled(True) - if self.fixed_check.isEnabled(): - self.global_checkbox.show() - self.frame.show() - self.is_linked = None - - self.state_changed.emit() + self.set_fixed(is_text) class QSaveModelDialog(QtWidgets.QDialog, Ui_SaveDialog): @@ -282,8 +215,17 @@ class FitModelTree(QtWidgets.QTreeWidget): idx = item.data(0, self.counterRole) self.itemRemoved.emit(idx) - def add_function(self, idx: int, cnt: int, op: int, name: str, color: QtGui.QColor | str | tuple, - parent: QtWidgets.QTreeWidgetItem = None, children: list = None, active: bool = True, **kwargs): + def add_function(self, + idx: int, + cnt: int, + op: int, + name: str, + color: QtGui.QColor | str | tuple, + parent: QtWidgets.QTreeWidgetItem = None, + children: list = None, + active: bool = True, + param_names: list[str] = None, + **kwargs): """ Add function to tree and dictionary of functions. """ @@ -298,6 +240,10 @@ class FitModelTree(QtWidgets.QTreeWidget): it.setData(0, self.counterRole, cnt) it.setData(0, self.operatorRole, op) it.setText(0, name) + if param_names is not None: + it.setToolTip(0, + 'Parameter names:\n' + + '\n'.join(f'{pn}({cnt})' for pn in param_names)) it.setForeground(0, QtGui.QBrush(color)) it.setIcon(0, get_icon(self.icons[op])) diff --git a/src/gui_qt/fit/fit_parameter.py b/src/gui_qt/fit/fit_parameter.py index 6c3ec60..30b4b31 100644 --- a/src/gui_qt/fit/fit_parameter.py +++ b/src/gui_qt/fit/fit_parameter.py @@ -1,5 +1,8 @@ from __future__ import annotations +from typing import Optional + +from nmreval.fit.parameter import Parameter from nmreval.utils.text import convert from ..Qt import QtWidgets, QtCore, QtGui @@ -62,8 +65,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): self.glob_values = [1] * len(func.params) for k, v in enumerate(func.params): - name = convert(v) - widgt = FitModelWidget(label=name, parent=self.scrollwidget) + widgt = FitModelWidget(label=v, parent=self.scrollwidget) widgt.parameter_pos = k widgt.func_idx = idx try: @@ -83,7 +85,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): self.global_parameter.append(widgt) self.scrollwidget.layout().addWidget(widgt) - widgt2 = ParameterSingleWidget(name=name, parent=self.scrollwidget2) + widgt2 = ParameterSingleWidget(name=v, parent=self.scrollwidget2) widgt2.valueChanged.connect(self.change_single_parameter) widgt2.removeSingleValue.connect(self.change_single_parameter) widgt2.installEventFilter(self) @@ -115,20 +117,22 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): self.scrollwidget.layout().addStretch(1) self.scrollwidget2.layout().addStretch(1) - def set_links(self, parameter): - for w in self.global_parameter: - if isinstance(w, FitModelWidget): - w.add_links(parameter) + # def set_links(self, parameter): + # for w in self.global_parameter: + # if isinstance(w, FitModelWidget): + # w.add_links(parameter) @QtCore.pyqtSlot(str) def change_global_parameter(self, value: str, idx: int = None): if idx is None: idx = self.global_parameter.index(self.sender()) - self.glob_values[idx] = float(value) + # self.glob_values[idx] = float(value) + self.glob_values[idx] = value if self.data_values[self.comboBox.currentData()][idx] is None: self.data_parameter[idx].blockSignals(True) - self.data_parameter[idx].value = float(value) + # self.data_parameter[idx].value = float(value) + self.data_parameter[idx].value = value self.data_parameter[idx].blockSignals(False) @QtCore.pyqtSlot(str, object) @@ -171,7 +175,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): # disable single parameter if it is set global, enable if global is unset widget = self.sender() idx = self.global_parameter.index(widget) - enable = (widget.global_checkbox.checkState() == QtCore.Qt.Unchecked) and (widget.is_linked is None) + enable = (widget.global_checkbox.checkState() == QtCore.Qt.Unchecked) self.data_parameter[idx].setEnabled(enable) def select_next_preview(self, direction): @@ -204,64 +208,50 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): if sid not in self.data_values: self.data_values[sid] = [None] * len(self.data_parameter) - def get_parameter(self, use_func=None): + def get_parameter(self, use_func=None) -> tuple[dict, list]: bds = [] is_global = [] is_fixed = [] - globs = [] - is_linked = [] + param_general = [] for g in self.global_parameter: if isinstance(g, FitModelWidget): - p_i, bds_i, fixed_i, global_i, link_i = g.get_parameter() + p_i, bds_i, fixed_i, global_i = g.get_parameter() + parameter_i = Parameter(name=g.name, value=p_i, lb=bds_i[0], ub=bds_i[1], var=fixed_i) + param_general.append(parameter_i) - globs.append(p_i) bds.append(bds_i) is_fixed.append(fixed_i) is_global.append(global_i) - is_linked.append(link_i) - - lb, ub = list(zip(*bds)) data_parameter = {} if use_func is None: use_func = list(self.data_values.keys()) - global_p = None for sid, parameter in self.data_values.items(): if sid not in use_func: continue kw_p = {} p = [] - if global_p is None: - global_p = {'p': [], 'idx': [], 'var': [], 'ub': [], 'lb': []} for i, (p_i, g) in enumerate(zip(parameter, self.global_parameter)): if isinstance(g, FitModelWidget): if (p_i is None) or is_global[i]: - p.append(globs[i]) - if is_global[i]: - if i not in global_p['idx']: - global_p['p'].append(globs[i]) - global_p['idx'].append(i) - global_p['var'].append(is_fixed[i]) - global_p['ub'].append(ub[i]) - global_p['lb'].append(lb[i]) + # set has no oen value + p.append(param_general[i].copy()) else: - p.append(p_i) + lb, ub = bds[i] + try: + if not (lb < p_i < ub): + raise ValueError(f'Parameter {g.name} is outside bounds ({lb}, {ub})') + except TypeError: + pass - try: - if p[i] > ub[i]: - raise ValueError(f'Parameter {g.name} is outside bounds ({lb[i]}, {ub[i]})') - except TypeError: - pass - - try: - if p[i] < lb[i]: - raise ValueError(f'Parameter {g.name} is outside bounds ({lb[i]}, {ub[i]})') - except TypeError: - pass + # create Parameter + p.append( + Parameter(name=g.name, value=p_i, lb=lb, ub=ub, var=is_fixed[i]) + ) else: if p_i is None: @@ -273,7 +263,15 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit): data_parameter[sid] = (p, kw_p) - return data_parameter, lb, ub, is_fixed, global_p, is_linked + global_parameter = [] + for param, global_flag in zip(param_general, is_global): + if global_flag: + param.is_global = True + global_parameter.append(param) + else: + global_parameter.append(None) + + return data_parameter, global_parameter def set_parameter(self, set_id: str | None, parameter: list[float]) -> int: num_parameter = list(filter(lambda g: not isinstance(g, SelectionWidget), self.global_parameter)) @@ -304,12 +302,12 @@ class ParameterSingleWidget(QtWidgets.QWidget): self._init_ui() - self._name = name + self.name = name self.label.setText(convert(name)) self.label.setToolTip('If this is bold then this parameter is only for this data. ' 'Otherwise, the general parameter is used and displayed') - self.value_line.setValidator(QtGui.QDoubleValidator()) + # self.value_line.setValidator(QtGui.QDoubleValidator()) self.value_line.textChanged.connect(lambda: self.valueChanged.emit(self.value) if self.value is not None else 0) self.reset_button.clicked.connect(lambda x: self.removeSingleValue.emit()) @@ -343,9 +341,10 @@ class ParameterSingleWidget(QtWidgets.QWidget): @value.setter def value(self, val): - self.value_line.setText(f'{float(val):.5g}') + # self.value_line.setText(f'{float(val):.5g}') + self.value_line.setText(f'{val}') - def show_as_local_parameter(self, is_local): + def show_as_local_parameter(self, is_local: bool): if is_local: self.label.setStyleSheet('font-weight: bold;') else: diff --git a/src/gui_qt/fit/fitfunction.py b/src/gui_qt/fit/fitfunction.py index 0d1c412..2b74641 100644 --- a/src/gui_qt/fit/fitfunction.py +++ b/src/gui_qt/fit/fitfunction.py @@ -128,7 +128,7 @@ class QFunctionWidget(QtWidgets.QWidget, Ui_Form): self.newFunction.emit(idx, cnt) - self.add_function(idx, cnt, op, name, col) + self.add_function(idx, cnt, op, name, col, param_names=self.functions[idx].params) def add_function(self, idx: int, cnt: int, op: int, name: str, color: str | tuple[float, float, float] | BaseColor, **kwargs): @@ -141,6 +141,7 @@ class QFunctionWidget(QtWidgets.QWidget, Ui_Form): qcolor = QtGui.QColor.fromRgbF(*color) else: qcolor = QtGui.QColor(color) + self.functree.add_function(idx, cnt, op, name, qcolor, **kwargs) f = self.functions[idx] diff --git a/src/gui_qt/fit/fitwindow.py b/src/gui_qt/fit/fitwindow.py index 519216e..9888685 100644 --- a/src/gui_qt/fit/fitwindow.py +++ b/src/gui_qt/fit/fitwindow.py @@ -9,6 +9,9 @@ import numpy as np from pyqtgraph import mkPen from nmreval.fit._meta import MultiModel, ModelFactory +from nmreval.fit.data import Data +from nmreval.fit.model import Model +from nmreval.fit.parameter import Parameters from nmreval.fit.result import FitResult from .fit_forms import FitTableWidget @@ -116,7 +119,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): # collect parameter names etc. to allow linkage self._func_list[self._current_model] = self.functionwidget.get_parameter_list() - dialog.set_links(self._func_list) + # dialog.set_links(self._func_list) # show same tab (general parameter/Data parameter) tab_idx = 0 @@ -219,57 +222,49 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): def _prepare(self, model: list, function_use: list = None, parameter: dict = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]: + if parameter is None: - parameter = {'parameter': {}, 'lb': (), 'ub': (), 'var': [], - 'glob': {'idx': [], 'p': [], 'var': [], 'lb': [], 'ub': []}, - 'links': [], 'color': []} + parameter = { + 'data_parameter': {}, + 'global_parameter': [], + 'links': [], + 'color': [], + } for i, f in enumerate(model): if not f['active']: continue try: - p, lb, ub, var, glob, links = self.param_widgets[f['cnt']].get_parameter(function_use) + p, glob = self.param_widgets[f['cnt']].get_parameter(function_use) except ValueError as e: _ = QtWidgets.QMessageBox().warning(self, 'Invalid value', str(e), QtWidgets.QMessageBox.Ok) return {}, -1 - p_len = len(parameter['lb']) - - parameter['lb'] += lb - parameter['ub'] += ub - parameter['var'] += var - parameter['links'] += links - parameter['color'] += [f['color']] + parameter['color'].append(f['color']) + parameter['global_parameter'].extend(glob) cnt = f['cnt'] - for p_k, v_k in p.items(): if add_idx: kw_k = {f'{k}_{cnt}': v for k, v in v_k[1].items()} else: kw_k = v_k[1] - if p_k in parameter['parameter']: - params, kw = parameter['parameter'][p_k] + if p_k in parameter['data_parameter']: + params, kw = parameter['data_parameter'][p_k] params += v_k[0] kw.update(kw_k) else: - parameter['parameter'][p_k] = (v_k[0], kw_k) - - for g_k, g_v in glob.items(): - if g_k != 'idx': - parameter['glob'][g_k] += g_v - else: - parameter['glob']['idx'] += [idx_i + p_len for idx_i in g_v] + parameter['data_parameter'][p_k] = (v_k[0], kw_k) if add_idx: cnt += 1 if f['children']: # recurse for children - child_parameter, cnt = self._prepare(f['children'], parameter=parameter, add_idx=add_idx, cnt=cnt) + _, cnt = self._prepare(f['children'], parameter=parameter, add_idx=add_idx, cnt=cnt) return parameter, cnt @@ -280,30 +275,43 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): data = self.data_table.collect_data(default=self.default_combobox.currentData()) func_dict = {} - for k, mod in self.models.items(): - func, order, param_len = ModelFactory.create_from_list(mod) + for model_name, model_parameter in self.models.items(): + func, order, param_len = ModelFactory.create_from_list(model_parameter) if func is None: continue - if k in data: - parameter, _ = self._prepare(mod, function_use=data[k], add_idx=isinstance(func, MultiModel)) + func = Model(func) + + if model_name in data: + parameter, _ = self._prepare(model_parameter, function_use=data[model_name], add_idx=isinstance(func, MultiModel)) + if parameter is None: return + for (data_parameter, _) in parameter['data_parameter'].values(): + for pname, param in zip(func.params, data_parameter): + param.name = pname + + if self._complex[model_name] is not None: + for p_k, p_v in parameter['data_parameter'].items(): + p_v[1].update({'complex_mode': self._complex[model_name]}) + parameter['data_parameter'][p_k] = p_v[0], p_v[1] + + for pname, param_value in zip(func.params, parameter['global_parameter']): + if param_value is not None: + param_value.name = pname + func.set_global_parameter(param_value) + parameter['func'] = func parameter['order'] = order parameter['len'] = param_len - parameter['complex'] = self._complex[k] - if self._complex[k] is not None: - for p_k, p_v in parameter['parameter'].items(): - p_v[1].update({'complex_mode': self._complex[k]}) - parameter['parameter'][p_k] = p_v[0], p_v[1] + parameter['complex'] = self._complex[model_name] - func_dict[k] = parameter + func_dict[model_name] = parameter replaceable = [] - for k, v in func_dict.items(): + for model_name, v in func_dict.items(): for i, link_i in enumerate(v['links']): if link_i is None: continue @@ -334,7 +342,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): QtWidgets.QMessageBox.Ok) return - replaceable.append((k, i, rep_model, repl_idx)) + replaceable.append((model_name, i, rep_model, repl_idx)) replace_value = None for p_k in f['parameter'].values(): @@ -412,31 +420,37 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): def make_previews(self, x, models_parameters: dict): self.preview_lines = [] + # needed to create namespace + param_dict = Parameters() + + cnt = 0 + for model in models_parameters.values(): + f = model['func'] + for parameter_list in model['data_parameter'].values(): + for i, p_value in enumerate(parameter_list[0]): + p_value.name = f.params[i] + param_dict.add_parameter(f'a{cnt}', p_value) + cnt += 1 + for k, model in models_parameters.items(): f = model['func'] is_complex = self._complex[k] - parameters = model['parameter'] + parameters = model['data_parameter'] color = model['color'] - seen_parameter = [] - for p, kwargs in parameters.values(): - if (p, kwargs) in seen_parameter: - # plot only previews with different parameter - continue - - seen_parameter.append((p, kwargs)) + p_value = [pp.value for pp in p] if is_complex is not None: - y = f.func(x, *p, complex_mode=is_complex, **kwargs) + y = f.func(x, *p_value, complex_mode=is_complex, **kwargs) if np.iscomplexobj(y): self.preview_lines.append(PlotItem(x=x, y=y.real, pen=mkPen(width=3))) self.preview_lines.append(PlotItem(x=x, y=y.imag, pen=mkPen(width=3))) else: self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3))) else: - y = f.func(x, *p, **kwargs) + y = f.func(x, *p_value, **kwargs) self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3))) if isinstance(f, MultiModel): @@ -444,7 +458,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): if is_complex is not None: sub_kwargs.update({'complex_mode': is_complex}) - for i, s in enumerate(f.subs(x, *p, **sub_kwargs)): + for i, s in enumerate(f.subs(x, *p_value, **sub_kwargs)): pen_i = mkPen(QtGui.QColor.fromRgbF(*color[i])) if np.iscomplexobj(s): self.preview_lines.append(PlotItem(x=x, y=s.real, pen=pen_i)) @@ -452,15 +466,17 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): else: self.preview_lines.append(PlotItem(x=x, y=s, pen=pen_i)) + param_dict.clear() + return self.preview_lines def set_parameter(self, parameter: dict[str, FitResult]): # which data uses which model data = self.data_table.collect_data(default=self.default_combobox.currentData()) - glob_fit_parameter = [] - for fitted_model, fitted_data in data.items(): + glob_fit_parameter = [] + for fit_id, fit_curve in parameter.items(): if fit_id in fitted_data: fit_parameter = list(fit_curve.parameter.values()) diff --git a/src/gui_qt/graphs/drawings.py b/src/gui_qt/graphs/drawings.py index cd62db4..e87af0e 100644 --- a/src/gui_qt/graphs/drawings.py +++ b/src/gui_qt/graphs/drawings.py @@ -138,9 +138,7 @@ class DrawingsWidget(QtWidgets.QWidget, Ui_Form): graph_id = self.graph_comboBox.currentData() current_lines = self.lines[graph_id] - print(remove_rows) for i in reversed(remove_rows): - print(i) self.tableWidget.removeRow(i) self.line_deleted.emit(current_lines[i], graph_id) diff --git a/src/gui_qt/lib/mdiarea.py b/src/gui_qt/lib/mdiarea.py index 37f6c27..3dd6bbf 100644 --- a/src/gui_qt/lib/mdiarea.py +++ b/src/gui_qt/lib/mdiarea.py @@ -27,7 +27,6 @@ class MdiAreaTile(QtWidgets.QMdiArea): pos = QtCore.QPoint(0, 0) for win in window_list: - print(win.minimumSize()) win.setGeometry(rect) win.move(pos) diff --git a/src/gui_qt/lib/randpok.py b/src/gui_qt/lib/randpok.py deleted file mode 100644 index 1cbfcf8..0000000 --- a/src/gui_qt/lib/randpok.py +++ /dev/null @@ -1,110 +0,0 @@ -import os.path -import json -import urllib.request -import webbrowser -import random - -from ..Qt import QtGui, QtCore, QtWidgets -from .._py.pokemon import Ui_Dialog - -random.seed() - - -class QPokemon(QtWidgets.QDialog, Ui_Dialog): - def __init__(self, number=None, parent=None): - super().__init__(parent=parent) - self.setupUi(self) - self._js = json.load(open(os.path.join(path_to_module, 'utils', 'pokemon.json'), 'r'), encoding='UTF-8') - self._id = 0 - - if number is not None and number in range(1, len(self._js)+1): - poke_nr = f'{number:03d}' - self._id = number - else: - poke_nr = f'{random.randint(1, len(self._js)):03d}' - self._id = int(poke_nr) - - self._pokemon = None - self.show_pokemon(poke_nr) - self.label_15.linkActivated.connect(lambda x: webbrowser.open(x)) - - self.buttonBox.clicked.connect(self.randomize) - self.next_button.clicked.connect(self.show_next) - self.prev_button.clicked.connect(self.show_prev) - - def show_pokemon(self, nr): - self._pokemon = self._js[nr] - self.setWindowTitle('Pokémon: ' + self._pokemon['Deutsch']) - - for i in range(self.tabWidget.count(), -1, -1): - print('i', self.tabWidget.count(), i) - try: - self.tabWidget.widget(i).deleteLater() - except AttributeError: - pass - - for n, img in self._pokemon['Bilder']: - w = QtWidgets.QWidget() - vl = QtWidgets.QVBoxLayout() - l = QtWidgets.QLabel(self) - l.setAlignment(QtCore.Qt.AlignHCenter) - pixmap = QtGui.QPixmap() - - try: - pixmap.loadFromData(urllib.request.urlopen(img, timeout=0.5).read()) - except IOError: - l.setText(n) - else: - sc_pixmap = pixmap.scaled(256, 256, QtCore.Qt.KeepAspectRatio) - l.setPixmap(sc_pixmap) - - vl.addWidget(l) - w.setLayout(vl) - self.tabWidget.addTab(w, n) - - if len(self._pokemon['Bilder']) <= 1: - self.tabWidget.tabBar().setVisible(False) - else: - self.tabWidget.tabBar().setVisible(True) - self.tabWidget.adjustSize() - - self.name.clear() - keys = ['National-Dex', 'Kategorie', 'Typ', 'Größe', 'Gewicht', 'Farbe', 'Link'] - label_list = [self.pokedex_nr, self.category, self.poketype, self.weight, self.height, self.color, self.info] - for (k, label) in zip(keys, label_list): - v = self._pokemon[k] - if isinstance(v, list): - v = os.path.join('', *v) - - if k == 'Link': - v = '{}'.format(v, v) - - label.setText(v) - - for k in ['Deutsch', 'Japanisch', 'Englisch', 'Französisch']: - v = self._pokemon[k] - self.name.addItem(k + ': ' + v) - - self.adjustSize() - - def randomize(self, idd): - if idd.text() == 'Retry': - new_number = f'{random.randint(1, len(self._js)):03d}' - self._id = int(new_number) - self.show_pokemon(new_number) - else: - self.close() - - def show_next(self): - new_number = self._id + 1 - if new_number > len(self._js): - new_number = 1 - self._id = new_number - self.show_pokemon(f'{new_number:03d}') - - def show_prev(self): - new_number = self._id - 1 - if new_number == 0: - new_number = len(self._js) - self._id = new_number - self.show_pokemon(f'{new_number:03d}') diff --git a/src/gui_qt/main/management.py b/src/gui_qt/main/management.py index ef2c4bc..db28c00 100644 --- a/src/gui_qt/main/management.py +++ b/src/gui_qt/main/management.py @@ -441,7 +441,7 @@ class UpperManagement(QtCore.QObject): # all-encompassing error catch try: for model_id, model_p in parameter.items(): - m = Model(model_p['func']) + m = model_p['func'] models[model_id] = m m_complex = model_p['complex'] @@ -450,13 +450,16 @@ class UpperManagement(QtCore.QObject): # iterate over order of set id in active order and access parameter inside loop # instead of directly looping try: - list_ids = list(model_p['parameter'].keys()) + list_ids = list(model_p['data_parameter'].keys()) set_order = [self.active_id.index(i) for i in list_ids] except ValueError as e: raise Exception('Getting order failed') from e for pos in set_order: set_id = list_ids[pos] + + data_i = self.data[set_id] + set_params = model_p['data_parameter'][set_id] try: data_i = self.data[set_id] except KeyError as e: @@ -488,7 +491,7 @@ class UpperManagement(QtCore.QObject): inside = np.where((_x >= x_lim[0]) & (_x <= x_lim[1])) else: inside = np.where((_x >= fit_limits[0]) & (_x <= fit_limits[1])) - + try: if isinstance(we, str): d = fit_d.Data(_x[inside], _y[inside], we=we, idx=set_id) @@ -499,18 +502,12 @@ class UpperManagement(QtCore.QObject): d.set_model(m) try: - d.set_parameter(set_params[0], var=model_p['var'], - lb=model_p['lb'], ub=model_p['ub'], - fun_kwargs=set_params[1]) + d.set_parameter(set_params[0], fun_kwargs=set_params[1]) except Exception as e: raise Exception('Setting parameter failed') from e self.fitter.add_data(d) - model_globs = model_p['glob'] - if model_globs: - m.set_global_parameter(**model_p['glob']) - for links_i in links: self.fitter.set_link_parameter((models[links_i[0]], links_i[1]), (models[links_i[2]], links_i[3])) @@ -1170,7 +1167,6 @@ class UpperManagement(QtCore.QObject): @QtCore.pyqtSlot(dict) def calc_relaxation(self, opts: dict): - params = opts['pts'] if len(params) == 4: if params[3]: diff --git a/src/nmreval/fit/data.py b/src/nmreval/fit/data.py index 42b95bb..4a34409 100644 --- a/src/nmreval/fit/data.py +++ b/src/nmreval/fit/data.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import numpy as np from .model import Model -from .parameter import Parameters +from .parameter import Parameters, Parameter class Data(object): @@ -16,7 +18,7 @@ class Data(object): self.model = None self.minimizer = None self.parameter = Parameters() - self.para_keys = None + self.para_keys: list = [] self.fun_kwargs = {} def __len__(self): @@ -68,12 +70,19 @@ class Data(object): def get_model(self): return self.model - def set_parameter(self, parameter, var=None, ub=None, lb=None, - default_bounds=False, fun_kwargs=None): + def set_parameter(self, + values: list[float | Parameter], + *, + var: list[bool] = None, + ub: list[float] = None, + lb: list[float] = None, + default_bounds: bool = False, + fun_kwargs: dict = None + ): """ Creates parameter for this data. If no Model is available, it falls back to the model - :param parameter: list of parameters + :param values: list of parameters :param var: list of boolean or boolean; False fixes parameter at given list index. Single value is broadcast to all parameter :param ub: list of upper boundaries or float; Single value is broadcast to all parameter. @@ -87,23 +96,46 @@ class Data(object): model = self.model if model is None: # Data has no unique - if self.minimizer is None: - model = None - else: + if self.minimizer is not None: model = self.minimizer.fit_model - self.fun_kwargs.update(model.fun_kwargs) if model is None: raise ValueError('No model found, please set model before parameters') - if default_bounds: + if len(values) != len(model.params): + raise ValueError('Number of given parameter does not match number of model parameters') + + is_parameter = [isinstance(v, Parameter) for v in values] + if all(is_parameter): + for p_i in values: + key = f"p{next(Parameters.parameter_counter)}" + self.parameter.add_parameter(key, p_i) + elif any(is_parameter): + raise ValueError('list of parameter are not all float of Parameter') + + else: + if var is None: + var = [True] * len(values) + if lb is None: - lb = model.lb + if default_bounds: + lb = model.lb + else: + lb = [None] * len(values) + if ub is None: - ub = model.ub + if default_bounds: + ub = model.ub + else: + ub = [None] * len(values) - self.para_keys = self.parameter.add_parameter(parameter, var=var, lb=lb, ub=ub) + arg_names = ['name', 'value', 'var', 'lb', 'ub'] + for parameter_arg in zip(model.params, values, var, lb, ub): + self.parameter.add(**{arg_name: arg_value for arg_name, arg_value in zip(arg_names, parameter_arg)}) + self.para_keys = list(self.parameter.keys()) + + self.fun_kwargs.update(model.fun_kwargs) if fun_kwargs is not None: self.fun_kwargs.update(fun_kwargs) @@ -123,6 +155,18 @@ class Data(object): else: return [p.value for p in self.minimizer.parameters[self.parameter]] + def replace_parameter(self, key: str, parameter: Parameter) -> None: + tobereplaced = None + for k, v in self.parameter.items(): + if v.name == parameter.name: + tobereplaced = k + break + + if tobereplaced is None: + raise KeyError(f'Global parameter {key} not found in list of parameters') + self.para_keys[self.para_keys.index(tobereplaced)] = key + self.parameter.replace_parameter(tobereplaced, key, parameter) + def cost(self, p): """ Cost function :math:`y-f(p, x)` diff --git a/src/nmreval/fit/minimizer.py b/src/nmreval/fit/minimizer.py index d66db74..5594fd1 100644 --- a/src/nmreval/fit/minimizer.py +++ b/src/nmreval/fit/minimizer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings from itertools import product @@ -21,13 +23,70 @@ class FitAbortException(Exception): pass +# COST FUNCTIONS: f(x) - y (least_square, minimize), and f(x) (ODR) +def _cost_scipy_glob(p: list[float], data: list[Data], varpars: list[str], used_pars: list[list[str]]): + # replace values + for keys, values in zip(varpars, p): + for data_i in data: + if keys in data_i.parameter.keys(): + # TODO move this to scaled_value setter + data_i.parameter[keys].scaled_value = values + data_i.parameter[keys].namespace[keys] = data_i.parameter[keys].value + r = [] + # unpack parameter and calculate y values and concatenate all + for values, p_idx in zip(data, used_pars): + actual_parameters = [values.parameter[keys].value for keys in p_idx] + r = np.r_[r, values.cost(actual_parameters)] + + return r + + +def _cost_scipy(p, data, varpars, used_pars): + for keys, values in zip(varpars, p): + data.parameter[keys].scaled_value = values + data.parameter[keys].namespace[keys] = data.parameter[keys].value + + actual_parameters = [data.parameter[keys].value for keys in used_pars] + return data.cost(actual_parameters) + + +def _cost_odr(p: list[float], data: Data, varpars: list[str], used_pars: list[str], fitmode: int=0): + for keys, values in zip(varpars, p): + data.parameter[keys].scaled_value = values + data.parameter[keys].namespace[keys] = data.parameter[keys].value + + actual_parameters = [data.parameter[keys].value for keys in used_pars] + + return data.func(actual_parameters, data.x) + + +def _cost_odr_glob(p: list[float], data: list[Data], var_pars: list[str], used_pars: list[str]): + # replace values + for data_i in data: + _update_parameter(data_i, var_pars, p) + + r = [] + # unpack parameter and calculate y values and concatenate all + for values, p_idx in zip(data, used_pars): + actual_parameters = [values.parameter[keys].value for keys in p_idx] + r = np.r_[r, values.func(actual_parameters, values.x)] + + return r + + +def _update_parameter(data: Data, varied_keys: list[str], parameter: list[float]): + for keys, values in zip(varied_keys, parameter): + if keys in data.parameter.keys(): + data.parameter[keys].scaled_value = values + data.parameter[keys].namespace[keys] = data.parameter[keys].value + + class FitRoutine(object): def __init__(self, mode='lsq'): self.fitmethod = mode self.data = [] self.fit_model = None self._no_own_model = [] - self.parameter = Parameters() self.result = [] self.linked = [] self._abort = False @@ -81,29 +140,27 @@ class FitRoutine(object): return self.fit_model - def set_link_parameter(self, parameter: tuple, replacement: tuple): + def set_link_parameter(self, dismissed_param: tuple[Model | Data, str], replacement: tuple[Model, str]): if isinstance(replacement[0], Model): - if replacement[1] not in replacement[0].global_parameter: - raise KeyError(f'Parameter at pos {replacement[1]} of ' - f'model {str(replacement[0])} is not global') + if replacement[1] not in replacement[0].parameter: + raise KeyError(f'Parameter {replacement[1]} of ' + f'model {replacement[0]} is not global') - if isinstance(parameter[0], Model): - warnings.warn(f'Replaced parameter at pos {parameter[1]} in {str(parameter[0])} ' + if isinstance(dismissed_param[0], Model): + warnings.warn(f'Replaced parameter {dismissed_param[1]} in {dismissed_param[0]} ' f'becomes global with linkage.') - self.linked.append((*parameter, *replacement)) + self.linked.append((*dismissed_param, *replacement)) def prepare_links(self): self._no_own_model = [] - self.parameter = Parameters() _found_models = {} linked_sender = {} for v in self.data: linked_sender[v] = set() - self.parameter.update(v.parameter.copy()) - # set temporaray model + # set temporary model if v.model is None: v.model = self.fit_model self._no_own_model.append(v) @@ -111,8 +168,6 @@ class FitRoutine(object): # register model if v.model not in _found_models: _found_models[v.model] = [] - m_param = v.model.parameter.copy() - self.parameter.update(m_param) _found_models[v.model].append(v) @@ -120,24 +175,21 @@ class FitRoutine(object): linked_sender[v.model] = set() linked_parameter = {} - for par, par_parm, repl, repl_par in self.linked: - if isinstance(par, Data): - if isinstance(repl, Data): - linked_parameter[par.para_keys[par_parm]] = repl.para_keys[repl_par] - else: - linked_parameter[par.para_keys[par_parm]] = repl.global_parameter[repl_par] + for dismiss_model, dismiss_param, replace_model, replace_param in self.linked: + linked_sender[replace_model].add(dismiss_model) + linked_sender[replace_model].add(replace_model) + replace_key = replace_model.parameter.get_key(replace_param) + dismiss_key = dismiss_model.parameter.get_key(dismiss_param) + + if isinstance(replace_model, Data): + linked_parameter[dismiss_key] = replace_key else: - if isinstance(repl, Data): - par.global_parameter[par_parm] = repl.para_keys[repl_par] - else: - par.global_parameter[par_parm] = repl.global_parameter[repl_par] - - linked_sender[repl].add(par) - linked_sender[par].add(repl) + p = dismiss_model.set_global_parameter(dismiss_param, replace_key) + p._expr_disp = replace_param for mm, m_data in _found_models.items(): - if mm.global_parameter: + if mm.parameter: for dd in m_data: linked_sender[mm].add(dd) linked_sender[dd].add(mm) @@ -169,15 +221,13 @@ class FitRoutine(object): logger.info('Fit aborted by user') self._abort = True - def run(self, mode: str=None): + def run(self, mode: str = None): self._abort = False - self.parameter = Parameters() if mode is None: mode = self.fitmethod fit_groups, linked_parameter = self.prepare_links() - for data_groups in fit_groups: if len(data_groups) == 1 and not self.linked: data = data_groups[0] @@ -208,8 +258,21 @@ class FitRoutine(object): self.unprep_run() + for r in self.result: + r.pprint() + return self.result + def make_preview(self, x: np.ndarray) -> list[np.ndarray]: + y_pred = [] + fit_groups, linked_parameter = self.prepare_links() + for data_groups in fit_groups: + data = data_groups[0] + actual_parameters = [p.value for p in data.parameter.values()] + y_pred.append(data.func(actual_parameters, x)) + + return y_pred + def _prep_data(self, data): if data.get_model() is None: data._model = self.fit_model @@ -237,22 +300,16 @@ class FitRoutine(object): var = [] data_pars = [] - # loopyloop over data that belong to one fit (linked or global) + # loopy-loop over data that belong to one fit (linked or global) for data in data_group: + # is parameter replaced by global parameter? + for k, v in data.model.parameter.items(): + data.replace_parameter(k, v) + actual_pars = [] - for i, (p_k, v_k) in enumerate(data.parameter.items()): + for i, p_k in enumerate(data.para_keys): p_k_used = p_k - v_k_used = v_k - - # is parameter replaced by global parameter? - if i in data.model.global_parameter: - p_k_used = data.model.global_parameter[i] - v_k_used = self.parameter[p_k_used] - - # links trump global parameter - if p_k_used in linked: - p_k_used = linked[p_k_used] - v_k_used = self.parameter[p_k_used] + v_k_used = data.parameter[p_k] actual_pars.append(p_k_used) # parameter is variable and was not found before as shared parameter @@ -271,48 +328,7 @@ class FitRoutine(object): d._model = None self._no_own_model = [] - - # COST FUNCTIONS: f(x) - y (least_square, minimize), and f(x) (ODR) - def __cost_scipy(self, p, data, varpars, used_pars): - for keys, values in zip(varpars, p): - self.parameter[keys].scaled_value = values - - actual_parameters = [self.parameter[keys].value for keys in used_pars] - return data.cost(actual_parameters) - - def __cost_odr(self, p, data, varpars, used_pars): - for keys, values in zip(varpars, p): - self.parameter[keys].scaled_value = values - - actual_parameters = [self.parameter[keys].value for keys in used_pars] - - return data.func(actual_parameters, data.x) - - def __cost_scipy_glob(self, p, data, varpars, used_pars): - # replace values - for keys, values in zip(varpars, p): - self.parameter[keys].scaled_value = values - - r = [] - # unpack parameter and calculate y values and concatenate all - for values, p_idx in zip(data, used_pars): - actual_parameters = [self.parameter[keys].value for keys in p_idx] - r = np.r_[r, values.cost(actual_parameters)] - - return r - - def __cost_odr_glob(self, p, data, varpars, used_pars): - # replace values - for keys, values in zip(varpars, p): - self.parameter[keys].scaled_value = values - - r = [] - # unpack parameter and calculate y values and concatenate all - for values, p_idx in zip(data, used_pars): - actual_parameters = [self.parameter[keys].value for keys in p_idx] - r = np.r_[r, values.func(actual_parameters, values.x)] - - return r + Parameters.reset() def _least_squares_single(self, data, p0, lb, ub, var): self.step = 0 @@ -322,7 +338,7 @@ class FitRoutine(object): if self._abort: raise FitAbortException(f'Fit aborted by user') - return self.__cost_scipy(p, data, var, data.para_keys) + return _cost_scipy(p, data, var, data.para_keys) with np.errstate(all='ignore'): res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0)) @@ -336,7 +352,7 @@ class FitRoutine(object): self.step += 1 if self._abort: raise FitAbortException(f'Fit aborted by user') - return self.__cost_scipy_glob(p, data, var, data_pars) + return _cost_scipy_glob(p, data, var, data_pars) with np.errstate(all='ignore'): res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0)) @@ -351,7 +367,7 @@ class FitRoutine(object): self.step += 1 if self._abort: raise FitAbortException(f'Fit aborted by user') - return (self.__cost_scipy(p, data, var, data.para_keys)**2).sum() + return (_cost_scipy(p, data, var, data.para_keys) ** 2).sum() with np.errstate(all='ignore'): res = optimize.minimize(cost, p0, bounds=[(b1, b2) for (b1, b2) in zip(lb, ub)], @@ -364,7 +380,7 @@ class FitRoutine(object): self.step += 1 if self._abort: raise FitAbortException(f'Fit aborted by user') - return (self.__cost_scipy_glob(p, data, var, data_pars)**2).sum() + return (_cost_scipy_glob(p, data, var, data_pars) ** 2).sum() with np.errstate(all='ignore'): res = optimize.minimize(cost, p0, bounds=[(b1, b2) for (b1, b2) in zip(lb, ub)], @@ -380,13 +396,18 @@ class FitRoutine(object): self.step += 1 if self._abort: raise FitAbortException(f'Fit aborted by user') - return self.__cost_odr(p, data, var_pars, data.para_keys) + return _cost_odr(p, data, var_pars, data.para_keys) odr_model = odr.Model(func) + corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0) + + self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)), + err=res.sd_beta, corr=corr, partial_corr=partial_corr) + + def _odr_fit(self, odr_data, odr_model, p0): o = odr.ODR(odr_data, odr_model, beta0=p0) res = o.run() - corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var try: corr_inv = np.linalg.inv(corr) @@ -395,16 +416,14 @@ class FitRoutine(object): partial_corr[np.diag_indices_from(partial_corr)] = 1. except np.linalg.LinAlgError: partial_corr = corr - - self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)), - err=res.sd_beta, corr=corr, partial_corr=partial_corr) + return corr, partial_corr, res def _odr_global(self, data, p0, var, data_pars): def func(p, _): self.step += 1 if self._abort: raise FitAbortException(f'Fit aborted by user') - return self.__cost_odr_glob(p, data, var, data_pars) + return _cost_odr_glob(p, data, var, data_pars) x = [] y = [] @@ -415,17 +434,7 @@ class FitRoutine(object): odr_data = odr.Data(x, y) odr_model = odr.Model(func) - o = odr.ODR(odr_data, odr_model, beta0=p0, ifixb=var) - res = o.run() - - corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var - try: - corr_inv = np.linalg.inv(corr) - corr_inv_diag = np.diag(np.sqrt(1 / np.diag(corr_inv))) - partial_corr = -1. * np.dot(np.dot(corr_inv_diag, corr_inv), corr_inv_diag) # Partial correlation matrix - partial_corr[np.diag_indices_from(partial_corr)] = 1. - except np.linalg.LinAlgError: - partial_corr = corr + corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0) for v, var_pars_k in zip(data, data_pars): self.make_results(v, res.beta, var, var_pars_k, (sum(len(d) for d in data), len(p0)), @@ -439,15 +448,17 @@ class FitRoutine(object): # update parameter values for keys, p_value, err_value in zip(var_pars, p, err): - self.parameter[keys].scaled_value = p_value - self.parameter[keys].scaled_error = err_value + if keys in data.parameter.keys(): + data.parameter[keys].scaled_value = p_value + data.parameter[keys].scaled_error = err_value + data.parameter[keys].namespace[keys] = data.parameter[keys].value combinations = list(product(var_pars, var_pars)) actual_parameters = [] corr_idx = [] for i, p_i in enumerate(used_pars): - actual_parameters.append(self.parameter[p_i]) + actual_parameters.append(data.parameter[p_i]) for j, p_j in enumerate(used_pars): try: # find the position of the parameter combinations @@ -508,3 +519,4 @@ class FitRoutine(object): partial_corr = corr return _err, corr, partial_corr + diff --git a/src/nmreval/fit/model.py b/src/nmreval/fit/model.py index c0121da..c80a2a3 100644 --- a/src/nmreval/fit/model.py +++ b/src/nmreval/fit/model.py @@ -6,7 +6,7 @@ from typing import Sized from numpy import inf from ._meta import MultiModel -from .parameter import Parameters +from .parameter import Parameters, Parameter class Model(object): @@ -25,7 +25,6 @@ class Model(object): self.ub = [i if i is not None else inf for i in self.ub] self.parameter = Parameters() - self.global_parameter = {} self.is_complex = None self._complex_part = False @@ -80,23 +79,33 @@ class Model(object): self.fun_kwargs = {k: v.default for k, v in inspect.signature(model.func).parameters.items() if v.default is not inspect.Parameter.empty} - def set_global_parameter(self, idx, p, var=None, lb=None, ub=None, default_bounds=False): - if idx is None: - self.parameter = Parameters() - self.global_parameter = {} - return + def set_global_parameter(self, + key: str | Parameter, + value: float | str = None, + *, + var: bool = None, + lb: float = None, + ub: float = None, + default_bounds: bool = False, + ) -> Parameter: - if default_bounds: - if lb is None: - lb = [self.lb[i] for i in idx] - if ub is None: - ub = [self.lb[i] for i in idx] + if isinstance(key, Parameter): + p = key + key = f'p{next(Parameters.parameter_counter)}' + self.parameter.add_parameter(key, p) - gp = self.parameter.add_parameter(p, var=var, lb=lb, ub=ub) - for k, v in zip(idx, gp): - self.global_parameter[k] = v + else: + idx = [self.params.index(key)] + if default_bounds: + if lb is None: + lb = [self.lb[i] for i in idx] + if ub is None: + ub = [self.lb[i] for i in idx] - return gp + p = self.parameter.add(key, value, var=var, lb=lb, ub=ub) + p.is_global = True + + return p @staticmethod def _prep(param_len, val): diff --git a/src/nmreval/fit/parameter.py b/src/nmreval/fit/parameter.py index fcaa284..6db56a3 100644 --- a/src/nmreval/fit/parameter.py +++ b/src/nmreval/fit/parameter.py @@ -1,124 +1,170 @@ from __future__ import annotations -from numbers import Number +import re from itertools import count - +from io import StringIO import numpy as np class Parameters(dict): - count = count() + parameter_counter = count() + # is one global namespace a good idea? + namespace: dict = {} - def __str__(self): - return 'Parameters:\n' + '\n'.join([str(k)+': '+str(v) for k, v in self.items()]) + def __init__(self): + super().__init__() + self._mapping: dict = {} - def __getitem__(self, item): - if isinstance(item, (list, tuple, np.ndarray)): - values = [] - for item_i in item: - values.append(super().__getitem__(item_i)) - return values + def __str__(self) -> str: + return 'Parameters:\n' + '\n'.join([f'{k}: {v}' for k, v in self.items()]) + + def __getitem__(self, item) -> Parameter: + if item in self._mapping: + return super().__getitem__(self._mapping[item]) else: return super().__getitem__(item) + def __setitem__(self, key, value): + self.add_parameter(key, value) + + def __contains__(self, item): + for v in self.values(): + if item == v.name: + return True + + return False + + def add(self, + name: str, + value: str | float | int = None, + *, + var: bool = True, + lb: float = -np.inf, ub: float = np.inf) -> Parameter: + + par = Parameter(name=name, value=value, var=var, lb=lb, ub=ub) + key = f'p{next(Parameters.parameter_counter)}' + + self.add_parameter(key, par) + + return par + + def add_parameter(self, key: str, parameter: Parameter): + self._mapping[parameter.name] = key + super().__setitem__(key, parameter) + + parameter.eval_allowed = False + self.namespace[key] = parameter.value + parameter.namespace = self.namespace + parameter.eval_allowed = True + + self.update_namespace() + + def replace_parameter(self, key_out: str, key_in: str, parameter: Parameter): + self.add_parameter(key_in, parameter) + for k, v in self._mapping.items(): + if v == key_out: + self._mapping[k] = key_in + break + + if key_out in self.namespace: + del self.namespace[key_out] + + def fix(self): + for v in self.keys(): + v._value = v.value + v.namespace = {} + @staticmethod - def _prep_bounds(val, p_len: int) -> list: - # helper function to ensure that bounds and variable are of parameter shape - if isinstance(val, (Number, bool)) or val is None: - return [val] * p_len + def reset(): + Parameters.namespace = {} - elif len(val) == p_len: - return val - - elif len(val) == 1: - return [val[0]] * p_len - - else: - raise ValueError('Input {} has wrong dimensions'.format(val)) - - def add_parameter(self, param, var=None, lb=None, ub=None): - if isinstance(param, Number): - param = [param] - - p_len = len(param) - - # make list if only single value is given - var = self._prep_bounds(var, p_len) - lb = self._prep_bounds(lb, p_len) - ub = self._prep_bounds(ub, p_len) - - new_keys = [] - for i in range(p_len): - new_idx = next(self.count) - new_keys.append(new_idx) - - self[new_idx] = Parameter(param[i], var=var[i], lb=lb[i], ub=ub[i]) - - return new_keys - - def copy(self): - p = Parameters() + def get_key(self, name: str) -> str | None: for k, v in self.items(): - p[k] = Parameter(v.value, var=v.var, lb=v.lb, ub=v.ub) + if name == v.name: + return k - if len(p) == 0: - return p - - max_k = max(p.keys()) - c = next(p.count) - while c < max_k: - c = next(p.count) - - return p + return def get_state(self): return {k: v.get_state() for k, v in self.items()} + def update_namespace(self): + for p in self.values(): + try: + p.value + except NameError: + expression = p._expr_disp + for n, k in self._mapping.items(): + expression, num_replaced = re.subn(re.escape(n), k, expression) + + p._expr = expression + class Parameter: """ Container for one parameter """ - __slots__ = ['name', 'value', 'error', 'init_val', 'var', 'lb', 'ub', 'scale', 'function'] - def __init__(self, value: float, var: bool = True, lb: float = -np.inf, ub: float = np.inf): - self.lb = lb if lb is not None else -np.inf - self.ub = ub if ub is not None else np.inf + # TODO Parameter should know its own key + def __init__(self, name: str, value: float | str, var: bool = True, lb: float = -np.inf, ub: float = np.inf): + self._value: float | None = None + self.var: bool = bool(var) if var is not None else True + self.error: None | float = None if self.var is False else 0.0 + self.name: str = name + self.function: str = "" - if self.lb <= value <= self.ub: - self.value = value + self.lb: None | float = lb if lb is not None else -np.inf + self.ub: float | None = ub if ub is not None else np.inf + self.namespace: dict = {} + self.eval_allowed: bool = True + self._expr: None | str = None + self._expr_disp: None | str = None + self.is_global = False + + if isinstance(value, str): + self._expr_disp = value + self._expr = value + self.var = False else: - print(value, self.lb, self.ub) - raise ValueError('Value of parameter is outside bounds') + if self.lb <= value <= self.ub: + self._value = value + else: + raise ValueError('Value of parameter is outside bounds') - self.init_val = value + self.init_val = value - with np.errstate(divide='ignore'): - # throws RuntimeWarning for zeros - self.scale = 10**(np.floor(np.log10(np.abs(self.value)))) + with np.errstate(divide='ignore'): + # throws RuntimeWarning for zeros + self.scale = 10**(np.floor(np.log10(np.abs(self.value)))) - if self.scale == 0: - self.scale = 1. + if self.scale == 0: + self.scale = 1. - self.var = bool(var) if var is not None else True - self.error = None if self.var is False else 0.0 - self.name = '' - self.function = '' - - def __str__(self): - start = '' + def __str__(self) -> str: + start = StringIO() if self.name: if self.function: - start = f'{self.name} ({self.function}): ' + start.write(f"{self.name} ({self.function})") else: - start = self.name + ': ' + start.write(self.name) + + if self.is_global: + start.write("*") + + start.write(": ") if self.var: - return start + f'{self.value:.4g} +/- {self.error:.4g}, init={self.init_val}' + start.write(f"{self.value:.4g} +/- {self.error:.4g}, init={self.init_val}") else: - return start + f'{self.value:} (fixed)' + start.write(f"{self.value:.4g}") + if self._expr is None: + start.write(" (fixed)") + else: + start.write(f" (calc: {self._expr_disp})") - def __add__(self, other: Parameter | float) -> float: + return start.getvalue() + + def __add__(self, other: Parameter | float | int) -> float: if isinstance(other, (float, int)): return self.value + other elif isinstance(other, Parameter): @@ -128,30 +174,39 @@ class Parameter: return self.__add__(other) @property - def scaled_value(self): + def scaled_value(self) -> float: return self.value / self.scale @scaled_value.setter - def scaled_value(self, value): - self.value = value * self.scale + def scaled_value(self, value: float) -> None: + self._value = value * self.scale @property - def scaled_error(self): - if self.error is None: - return self.error - else: + def value(self) -> float | None: + if self._value is not None: + return self._value + + if self._expr is not None and self.eval_allowed: + return eval(self._expr, {}, self.namespace) + + return + + @property + def scaled_error(self) -> None | float: + if self.error is not None: return self.error / self.scale + return + @scaled_error.setter - def scaled_error(self, value): + def scaled_error(self, value) -> None: self.error = value * self.scale - def get_state(self): - - return {slot: getattr(self, slot) for slot in self.__slots__} + def get_state(self) -> dict: + return {slot: getattr(self, slot) for slot in self.__slots__} @staticmethod - def set_state(state: dict): + def set_state(state: dict) -> Parameter: par = Parameter(state.pop('value')) for k, v in state.items(): setattr(par, k, v) @@ -159,9 +214,28 @@ class Parameter: return par @property - def full_name(self): + def full_name(self) -> str: name = self.name if self.function: - name += ' (' + self.function + ')' + name += f" ({self.function})" return name + + def copy(self) -> Parameter: + if self._expr: + val = self._expr_disp + else: + val = self._value + para_copy = Parameter(name=self.name, value=val, var=self.var, lb=self.lb, ub=self.ub) + para_copy._expr = self._expr + para_copy.namespace = self.namespace + para_copy.is_global = self.is_global + para_copy.error = self.error + para_copy.function = self.function + + return para_copy + + def fix(self): + self._value = self.value + self.namespace = {} + diff --git a/src/nmreval/fit/result.py b/src/nmreval/fit/result.py index a2a83d1..b460ad8 100644 --- a/src/nmreval/fit/result.py +++ b/src/nmreval/fit/result.py @@ -2,6 +2,7 @@ from __future__ import annotations import re from collections import OrderedDict +from io import StringIO from pathlib import Path from typing import Any @@ -186,7 +187,7 @@ class FitResult(Points): nice_name = m.group(1) if func_number in split_funcs: nice_func = split_funcs[func_number] - + pvalue.fix() pvalue.name = nice_name pvalue.function = nice_func parameter_dic[pname] = pvalue @@ -223,27 +224,30 @@ class FitResult(Points): return self.nobs-self.nvar def pprint(self, statistics=True, correlations=True): - print('Fit result:') - print(' model :', self.name) - print(' #data :', self.nobs) - print(' #var :', self.nvar) - print('\nParameter') - print(self.parameter_string()) + s = StringIO() + s.write('Fit result:\n') + s.write(f' model : {self.name}\n') + s.write(f' #data : {self.nobs}\n') + s.write(f' #var : {self.nvar}\n') + s.write('\nParameter\n') + s.write(self.parameter_string()) if statistics: - print('Statistics') + s.write('\nStatistics\n') for k, v in self.statistics.items(): - print(f' {k} : {v:.4f}') + s.write(f' {k} : {v:.4f}\n') if correlations and self.correlation is not None: - print('\nCorrelation (partial corr.)') - print(self._correlation_string()) - print() + s.write('\nCorrelation (partial corr.)\n') + s.write(self._correlation_string()) + s.write('\n') + + print(s.getvalue()) def parameter_string(self): ret_val = '' - for pval in self.parameter.values(): + for pkey, pval in self.parameter.items(): ret_val += convert(str(pval), old='tex', new='str') + '\n' if self.fun_kwargs: @@ -255,9 +259,7 @@ class FitResult(Points): def _correlation_string(self): ret_val = '' for p_i, p_j, corr_ij, pcorr_ij in self.correlation_list(): - ret_val += ' {} / {} : {:.4f} ({:.4f})\n'.format(convert(p_i, old='tex', new='str'), - convert(p_j, old='tex', new='str'), - corr_ij, pcorr_ij) + ret_val += f" {convert(p_i, old='tex', new='str')} / {convert(p_j, old='tex', new='str')} : {corr_ij:.4f} ({pcorr_ij:.4f})\n" return ret_val def correlation_list(self, limit=0.1): diff --git a/src/nmreval/models/spectrum.py b/src/nmreval/models/spectrum.py index 229cf66..30b6917 100644 --- a/src/nmreval/models/spectrum.py +++ b/src/nmreval/models/spectrum.py @@ -35,8 +35,8 @@ class Gaussian: class Lorentzian: type = 'Spectrum' name = 'Lorentzian' - equation = 'A (2/\pi)w/[4*(x-\mu)^{2} + w^{2}] + A_{0}' - params = ['A', '\mu', 'w', 'A_{0}'] + equation = r'A (2/\pi)w/[4*(x-\mu)^{2} + w^{2}] + A_{0}' + params = ['A', r'\mu', 'w', 'A_{0}'] ext_params = None bounds = [(0, None), (None, None), (0, None), (None, None)] @@ -62,9 +62,9 @@ class Lorentzian: class PseudoVoigt: type = 'Spectrum' name = 'Pseudo Voigt' - equation = 'A [R*2/\pi*w/[4*(x-\mu)^{2} + w^{2}] + ' \ - '(1-R)*sqrt(4*ln(2)/pi)/w*exp(-4*ln(2)[(x-\mu)/w]^{2})] + A_{0}' - params = ['A', 'R', '\mu', 'w', 'A_{0}'] + equation = r'A [R*2/\pi*w/[4*(x-\mu)^{2} + w^{2}] + ' \ + r'(1-R)*sqrt(4*ln(2)/pi)/w*exp(-4*ln(2)[(x-\mu)/w]^{2})] + A_{0}' + params = ['A', 'R', r'\mu', 'w', 'A_{0}'] ext_params = None bounds = [(0, None), (0, 1), (None, None), (0, None)] diff --git a/src/resources/_ui/fitmodelwidget.ui b/src/resources/_ui/fitmodelwidget.ui index 02069fb..ffc0b93 100755 --- a/src/resources/_ui/fitmodelwidget.ui +++ b/src/resources/_ui/fitmodelwidget.ui @@ -7,7 +7,7 @@ 0 0 365 - 78 + 66 @@ -62,7 +62,7 @@ - + 0 0 @@ -78,19 +78,6 @@ - - - - Qt::Horizontal - - - - 40 - 20 - - - - @@ -105,19 +92,6 @@ - - - - - - - QToolButton::InstantPopup - - - Qt::RightArrow - - -