Merge branch 'fit_constraints'

# Conflicts:
#	src/gui_qt/main/management.py
This commit is contained in:
Dominik Demuth 2023-09-19 12:39:32 +02:00
commit 04037d6b4d
18 changed files with 574 additions and 619 deletions

View File

@ -1,10 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Form implementation generated from reading ui file 'resources/_ui/fitmodelwidget.ui' # Form implementation generated from reading ui file 'src/resources/_ui/fitmodelwidget.ui'
# #
# Created by: PyQt5 UI code generator 5.12.3 # Created by: PyQt5 UI code generator 5.15.9
# #
# WARNING! All changes made in this file will be lost! # WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again. Do not edit this file unless you know what you are doing.
from PyQt5 import QtCore, QtGui, QtWidgets from PyQt5 import QtCore, QtGui, QtWidgets
@ -13,7 +14,7 @@ from PyQt5 import QtCore, QtGui, QtWidgets
class Ui_FitParameter(object): class Ui_FitParameter(object):
def setupUi(self, FitParameter): def setupUi(self, FitParameter):
FitParameter.setObjectName("FitParameter") FitParameter.setObjectName("FitParameter")
FitParameter.resize(365, 78) FitParameter.resize(365, 66)
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.MinimumExpanding) sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.MinimumExpanding)
sizePolicy.setHorizontalStretch(0) sizePolicy.setHorizontalStretch(0)
sizePolicy.setVerticalStretch(0) sizePolicy.setVerticalStretch(0)
@ -36,7 +37,7 @@ class Ui_FitParameter(object):
self.parametername.setObjectName("parametername") self.parametername.setObjectName("parametername")
self.horizontalLayout_2.addWidget(self.parametername) self.horizontalLayout_2.addWidget(self.parametername)
self.parameter_line = LineEdit(FitParameter) self.parameter_line = LineEdit(FitParameter)
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed) sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed)
sizePolicy.setHorizontalStretch(0) sizePolicy.setHorizontalStretch(0)
sizePolicy.setVerticalStretch(0) sizePolicy.setVerticalStretch(0)
sizePolicy.setHeightForWidth(self.parameter_line.sizePolicy().hasHeightForWidth()) sizePolicy.setHeightForWidth(self.parameter_line.sizePolicy().hasHeightForWidth())
@ -44,20 +45,12 @@ class Ui_FitParameter(object):
self.parameter_line.setText("") self.parameter_line.setText("")
self.parameter_line.setObjectName("parameter_line") self.parameter_line.setObjectName("parameter_line")
self.horizontalLayout_2.addWidget(self.parameter_line) self.horizontalLayout_2.addWidget(self.parameter_line)
spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
self.horizontalLayout_2.addItem(spacerItem)
self.fixed_check = QtWidgets.QCheckBox(FitParameter) self.fixed_check = QtWidgets.QCheckBox(FitParameter)
self.fixed_check.setObjectName("fixed_check") self.fixed_check.setObjectName("fixed_check")
self.horizontalLayout_2.addWidget(self.fixed_check) self.horizontalLayout_2.addWidget(self.fixed_check)
self.global_checkbox = QtWidgets.QCheckBox(FitParameter) self.global_checkbox = QtWidgets.QCheckBox(FitParameter)
self.global_checkbox.setObjectName("global_checkbox") self.global_checkbox.setObjectName("global_checkbox")
self.horizontalLayout_2.addWidget(self.global_checkbox) self.horizontalLayout_2.addWidget(self.global_checkbox)
self.toolButton = QtWidgets.QToolButton(FitParameter)
self.toolButton.setText("")
self.toolButton.setPopupMode(QtWidgets.QToolButton.InstantPopup)
self.toolButton.setArrowType(QtCore.Qt.RightArrow)
self.toolButton.setObjectName("toolButton")
self.horizontalLayout_2.addWidget(self.toolButton)
self.verticalLayout.addLayout(self.horizontalLayout_2) self.verticalLayout.addLayout(self.horizontalLayout_2)
self.frame = QtWidgets.QFrame(FitParameter) self.frame = QtWidgets.QFrame(FitParameter)
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Maximum) sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Maximum)

View File

@ -8,6 +8,7 @@ from pyqtgraph import mkPen
from nmreval.data.points import Points from nmreval.data.points import Points
from nmreval.data.signals import Signal from nmreval.data.signals import Signal
from nmreval.lib.logger import logger
from nmreval.utils.text import convert from nmreval.utils.text import convert
from nmreval.data.bds import BDS from nmreval.data.bds import BDS
from nmreval.data.dsc import DSC from nmreval.data.dsc import DSC
@ -356,7 +357,7 @@ class ExperimentContainer(QtCore.QObject):
elif mode in ['imag', 'all'] and self.plot_imag is not None: elif mode in ['imag', 'all'] and self.plot_imag is not None:
self.plot_imag.set_symbol(symbol=symbol, size=size, color=color) self.plot_imag.set_symbol(symbol=symbol, size=size, color=color)
else: else:
print('Updating symbol failed for ' + str(self.id)) logger.warning(f'Updating symbol failed for {self.id}')
def setLine(self, *, width=None, style=None, color=None, mode='real'): def setLine(self, *, width=None, style=None, color=None, mode='real'):
if mode in ['real', 'all']: if mode in ['real', 'all']:
@ -368,7 +369,7 @@ class ExperimentContainer(QtCore.QObject):
elif mode in ['imag', 'all'] and self.plot_imag is not None: elif mode in ['imag', 'all'] and self.plot_imag is not None:
self.plot_imag.set_line(width=width, style=style, color=color) self.plot_imag.set_line(width=width, style=style, color=color)
else: else:
print('Updating line failed for ' + str(self.id)) logger.warning(f'Updating line failed for {self.id}')
def update_property(self, key1: str, key2: str, value: Any): def update_property(self, key1: str, key2: str, value: Any):
keykey = key2.split() keykey = key2.split()

View File

@ -1,3 +1,4 @@
from nmreval.lib.logger import logger
from nmreval.math import apodization from nmreval.math import apodization
from nmreval.lib.importer import find_models from nmreval.lib.importer import find_models
from nmreval.utils.text import convert from nmreval.utils.text import convert
@ -67,7 +68,7 @@ class EditSignalWidget(QtWidgets.QWidget, Ui_Form):
self.do_something.emit(sender, (ph0, ph1, pvt)) self.do_something.emit(sender, (ph0, ph1, pvt))
else: else:
print('You should never reach this by accident.') logger.warning(f'You should never reach this by accident, invalid sender {sender!r}')
@QtCore.pyqtSlot(int, name='on_apodcombobox_currentIndexChanged') @QtCore.pyqtSlot(int, name='on_apodcombobox_currentIndexChanged')
def change_apodization(self, index): def change_apodization(self, index):

View File

@ -19,19 +19,20 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
super().__init__(parent) super().__init__(parent)
self.setupUi(self) self.setupUi(self)
self.parametername.setText(label + ' ') self.name = label
self.parametername.setText(convert(label) + ' ')
validator = QtGui.QDoubleValidator()
self.parameter_line.setValidator(validator)
self.parameter_line.setText('1') self.parameter_line.setText('1')
self.parameter_line.setMaximumWidth(240) self.parameter_line.setMaximumWidth(160)
self.lineEdit.setMaximumWidth(60) self.lineEdit.setMaximumWidth(100)
self.lineEdit_2.setMaximumWidth(60) self.lineEdit_2.setMaximumWidth(100)
self.label_3.setText(f'< {label} <') self.label_3.setText(f'< {convert(label)} <')
self.checkBox.stateChanged.connect(self.enableBounds) self.checkBox.stateChanged.connect(self.enableBounds)
self.global_checkbox.stateChanged.connect(lambda: self.state_changed.emit()) self.global_checkbox.stateChanged.connect(lambda: self.state_changed.emit())
self.parameter_line.editingFinished.connect(self.update_parameter)
self.parameter_line.values_requested.connect(lambda: self.value_requested.emit(self)) self.parameter_line.values_requested.connect(lambda: self.value_requested.emit(self))
self.parameter_line.replace_single_values.connect(lambda: self.replace_single_value.emit(None)) self.parameter_line.replace_single_values.connect(lambda: self.replace_single_value.emit(None))
self.parameter_line.editingFinished.connect(lambda: self.value_changed.emit(self.parameter_line.text())) self.parameter_line.editingFinished.connect(lambda: self.value_changed.emit(self.parameter_line.text()))
@ -40,18 +41,12 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
if fixed: if fixed:
self.fixed_check.hide() self.fixed_check.hide()
self.menu = QtWidgets.QMenu(self)
self.add_links()
self.is_linked = None
self.parameter_pos = None self.parameter_pos = None
self.func_idx = None self.func_idx = None
self._linetext = '1' self._linetext = '1'
@property self.menu = QtWidgets.QMenu(self)
def name(self):
return convert(self.parametername.text().strip(), old='html', new='str')
def set_parameter_string(self, p: str): def set_parameter_string(self, p: str):
self.parameter_line.setText(p) self.parameter_line.setText(p)
@ -71,11 +66,6 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
def set_parameter(self, p: float | None, bds: tuple[float, float, bool] = None, def set_parameter(self, p: float | None, bds: tuple[float, float, bool] = None,
fixed: bool = None, glob: bool = None): fixed: bool = None, glob: bool = None):
if p is None:
# bad hack: linked parameter return (None, linked parameter)
# if p is None -> parameter is linked to argument given by bds
self.link_parameter(linkto=bds)
else:
ptext = f'{p:.4g}' ptext = f'{p:.4g}'
self.set_parameter_string(ptext) self.set_parameter_string(ptext)
@ -90,19 +80,10 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
self.global_checkbox.setCheckState(QtCore.Qt.Checked if glob else QtCore.Qt.Unchecked) self.global_checkbox.setCheckState(QtCore.Qt.Checked if glob else QtCore.Qt.Unchecked)
def get_parameter(self): def get_parameter(self):
if self.is_linked:
try:
p = float(self._linetext)
except ValueError:
p = 1.0
else:
try: try:
p = float(self.parameter_line.text().replace(',', '.')) p = float(self.parameter_line.text().replace(',', '.'))
except ValueError: except ValueError:
_ = QtWidgets.QMessageBox().warning(self, 'Invalid value', p = self.parameter_line.text().replace(',', '.')
f'{self.parametername.text()} contains invalid values',
QtWidgets.QMessageBox.Cancel)
return None
if self.checkBox.isChecked(): if self.checkBox.isChecked():
try: try:
@ -119,75 +100,27 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
bounds = (lb, rb) bounds = (lb, rb)
return p, bounds, not self.fixed_check.isChecked(), self.global_checkbox.isChecked(), self.is_linked return p, bounds, not self.fixed_check.isChecked(), self.global_checkbox.isChecked()
@QtCore.pyqtSlot(bool) @QtCore.pyqtSlot(bool)
def set_fixed(self, state: bool): def set_fixed(self, state: bool):
# self.global_checkbox.setVisible(not state) # self.global_checkbox.setVisible(not state)
self.frame.setVisible(not state) self.frame.setVisible(not state)
def add_links(self, parameter: dict = None):
if parameter is None:
parameter = {}
self.menu.clear()
ac = QtWidgets.QAction('Link to...', self)
ac.triggered.connect(self.link_parameter)
self.menu.addAction(ac)
for model_key, model_funcs in parameter.items():
m = QtWidgets.QMenu('Model ' + model_key, self)
for func_name, func_params in model_funcs.items():
m2 = QtWidgets.QMenu(func_name, m)
for p_name, idx in func_params:
ac = QtWidgets.QAction(p_name, m2)
ac.setData((model_key, *idx))
ac.triggered.connect(self.link_parameter)
m2.addAction(ac)
m.addMenu(m2)
self.menu.addMenu(m)
self.toolButton.setMenu(self.menu)
@QtCore.pyqtSlot() @QtCore.pyqtSlot()
def link_parameter(self, linkto=None): def update_parameter(self):
if linkto is None: new_value = self.parameter_line.text()
action = self.sender() if not new_value:
else: self.parameter_line.setText('1')
action = False
for m in self.menu.actions():
if m.menu():
for a in m.menu().actions():
if a.data() == linkto:
action = a
break
if action:
break
if (self.func_idx, self.parameter_pos) == action.data():
return
try: try:
new_text = f'Linked to {action.parentWidget().title()}.{action.text()}' float(new_value)
self._linetext = self.parameter_line.text() is_text = False
self.parameter_line.setText(new_text) except ValueError:
self.parameter_line.setEnabled(False) is_text = True
self.global_checkbox.hide() self.global_checkbox.setCheckState(False)
self.global_checkbox.blockSignals(True)
self.global_checkbox.setCheckState(QtCore.Qt.Checked)
self.global_checkbox.blockSignals(False)
self.frame.hide()
self.is_linked = action.data()
except AttributeError: self.set_fixed(is_text)
self.parameter_line.setText(self._linetext)
self.parameter_line.setEnabled(True)
if self.fixed_check.isEnabled():
self.global_checkbox.show()
self.frame.show()
self.is_linked = None
self.state_changed.emit()
class QSaveModelDialog(QtWidgets.QDialog, Ui_SaveDialog): class QSaveModelDialog(QtWidgets.QDialog, Ui_SaveDialog):
@ -282,8 +215,17 @@ class FitModelTree(QtWidgets.QTreeWidget):
idx = item.data(0, self.counterRole) idx = item.data(0, self.counterRole)
self.itemRemoved.emit(idx) self.itemRemoved.emit(idx)
def add_function(self, idx: int, cnt: int, op: int, name: str, color: QtGui.QColor | str | tuple, def add_function(self,
parent: QtWidgets.QTreeWidgetItem = None, children: list = None, active: bool = True, **kwargs): idx: int,
cnt: int,
op: int,
name: str,
color: QtGui.QColor | str | tuple,
parent: QtWidgets.QTreeWidgetItem = None,
children: list = None,
active: bool = True,
param_names: list[str] = None,
**kwargs):
""" """
Add function to tree and dictionary of functions. Add function to tree and dictionary of functions.
""" """
@ -298,6 +240,10 @@ class FitModelTree(QtWidgets.QTreeWidget):
it.setData(0, self.counterRole, cnt) it.setData(0, self.counterRole, cnt)
it.setData(0, self.operatorRole, op) it.setData(0, self.operatorRole, op)
it.setText(0, name) it.setText(0, name)
if param_names is not None:
it.setToolTip(0,
'Parameter names:\n' +
'\n'.join(f'{pn}({cnt})' for pn in param_names))
it.setForeground(0, QtGui.QBrush(color)) it.setForeground(0, QtGui.QBrush(color))
it.setIcon(0, get_icon(self.icons[op])) it.setIcon(0, get_icon(self.icons[op]))

View File

@ -1,5 +1,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional
from nmreval.fit.parameter import Parameter
from nmreval.utils.text import convert from nmreval.utils.text import convert
from ..Qt import QtWidgets, QtCore, QtGui from ..Qt import QtWidgets, QtCore, QtGui
@ -62,8 +65,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
self.glob_values = [1] * len(func.params) self.glob_values = [1] * len(func.params)
for k, v in enumerate(func.params): for k, v in enumerate(func.params):
name = convert(v) widgt = FitModelWidget(label=v, parent=self.scrollwidget)
widgt = FitModelWidget(label=name, parent=self.scrollwidget)
widgt.parameter_pos = k widgt.parameter_pos = k
widgt.func_idx = idx widgt.func_idx = idx
try: try:
@ -83,7 +85,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
self.global_parameter.append(widgt) self.global_parameter.append(widgt)
self.scrollwidget.layout().addWidget(widgt) self.scrollwidget.layout().addWidget(widgt)
widgt2 = ParameterSingleWidget(name=name, parent=self.scrollwidget2) widgt2 = ParameterSingleWidget(name=v, parent=self.scrollwidget2)
widgt2.valueChanged.connect(self.change_single_parameter) widgt2.valueChanged.connect(self.change_single_parameter)
widgt2.removeSingleValue.connect(self.change_single_parameter) widgt2.removeSingleValue.connect(self.change_single_parameter)
widgt2.installEventFilter(self) widgt2.installEventFilter(self)
@ -115,20 +117,22 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
self.scrollwidget.layout().addStretch(1) self.scrollwidget.layout().addStretch(1)
self.scrollwidget2.layout().addStretch(1) self.scrollwidget2.layout().addStretch(1)
def set_links(self, parameter): # def set_links(self, parameter):
for w in self.global_parameter: # for w in self.global_parameter:
if isinstance(w, FitModelWidget): # if isinstance(w, FitModelWidget):
w.add_links(parameter) # w.add_links(parameter)
@QtCore.pyqtSlot(str) @QtCore.pyqtSlot(str)
def change_global_parameter(self, value: str, idx: int = None): def change_global_parameter(self, value: str, idx: int = None):
if idx is None: if idx is None:
idx = self.global_parameter.index(self.sender()) idx = self.global_parameter.index(self.sender())
self.glob_values[idx] = float(value) # self.glob_values[idx] = float(value)
self.glob_values[idx] = value
if self.data_values[self.comboBox.currentData()][idx] is None: if self.data_values[self.comboBox.currentData()][idx] is None:
self.data_parameter[idx].blockSignals(True) self.data_parameter[idx].blockSignals(True)
self.data_parameter[idx].value = float(value) # self.data_parameter[idx].value = float(value)
self.data_parameter[idx].value = value
self.data_parameter[idx].blockSignals(False) self.data_parameter[idx].blockSignals(False)
@QtCore.pyqtSlot(str, object) @QtCore.pyqtSlot(str, object)
@ -171,7 +175,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
# disable single parameter if it is set global, enable if global is unset # disable single parameter if it is set global, enable if global is unset
widget = self.sender() widget = self.sender()
idx = self.global_parameter.index(widget) idx = self.global_parameter.index(widget)
enable = (widget.global_checkbox.checkState() == QtCore.Qt.Unchecked) and (widget.is_linked is None) enable = (widget.global_checkbox.checkState() == QtCore.Qt.Unchecked)
self.data_parameter[idx].setEnabled(enable) self.data_parameter[idx].setEnabled(enable)
def select_next_preview(self, direction): def select_next_preview(self, direction):
@ -204,64 +208,50 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
if sid not in self.data_values: if sid not in self.data_values:
self.data_values[sid] = [None] * len(self.data_parameter) self.data_values[sid] = [None] * len(self.data_parameter)
def get_parameter(self, use_func=None): def get_parameter(self, use_func=None) -> tuple[dict, list]:
bds = [] bds = []
is_global = [] is_global = []
is_fixed = [] is_fixed = []
globs = [] param_general = []
is_linked = []
for g in self.global_parameter: for g in self.global_parameter:
if isinstance(g, FitModelWidget): if isinstance(g, FitModelWidget):
p_i, bds_i, fixed_i, global_i, link_i = g.get_parameter() p_i, bds_i, fixed_i, global_i = g.get_parameter()
parameter_i = Parameter(name=g.name, value=p_i, lb=bds_i[0], ub=bds_i[1], var=fixed_i)
param_general.append(parameter_i)
globs.append(p_i)
bds.append(bds_i) bds.append(bds_i)
is_fixed.append(fixed_i) is_fixed.append(fixed_i)
is_global.append(global_i) is_global.append(global_i)
is_linked.append(link_i)
lb, ub = list(zip(*bds))
data_parameter = {} data_parameter = {}
if use_func is None: if use_func is None:
use_func = list(self.data_values.keys()) use_func = list(self.data_values.keys())
global_p = None
for sid, parameter in self.data_values.items(): for sid, parameter in self.data_values.items():
if sid not in use_func: if sid not in use_func:
continue continue
kw_p = {} kw_p = {}
p = [] p = []
if global_p is None:
global_p = {'p': [], 'idx': [], 'var': [], 'ub': [], 'lb': []}
for i, (p_i, g) in enumerate(zip(parameter, self.global_parameter)): for i, (p_i, g) in enumerate(zip(parameter, self.global_parameter)):
if isinstance(g, FitModelWidget): if isinstance(g, FitModelWidget):
if (p_i is None) or is_global[i]: if (p_i is None) or is_global[i]:
p.append(globs[i]) # set has no oen value
if is_global[i]: p.append(param_general[i].copy())
if i not in global_p['idx']:
global_p['p'].append(globs[i])
global_p['idx'].append(i)
global_p['var'].append(is_fixed[i])
global_p['ub'].append(ub[i])
global_p['lb'].append(lb[i])
else: else:
p.append(p_i) lb, ub = bds[i]
try: try:
if p[i] > ub[i]: if not (lb < p_i < ub):
raise ValueError(f'Parameter {g.name} is outside bounds ({lb[i]}, {ub[i]})') raise ValueError(f'Parameter {g.name} is outside bounds ({lb}, {ub})')
except TypeError: except TypeError:
pass pass
try: # create Parameter
if p[i] < lb[i]: p.append(
raise ValueError(f'Parameter {g.name} is outside bounds ({lb[i]}, {ub[i]})') Parameter(name=g.name, value=p_i, lb=lb, ub=ub, var=is_fixed[i])
except TypeError: )
pass
else: else:
if p_i is None: if p_i is None:
@ -273,7 +263,15 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
data_parameter[sid] = (p, kw_p) data_parameter[sid] = (p, kw_p)
return data_parameter, lb, ub, is_fixed, global_p, is_linked global_parameter = []
for param, global_flag in zip(param_general, is_global):
if global_flag:
param.is_global = True
global_parameter.append(param)
else:
global_parameter.append(None)
return data_parameter, global_parameter
def set_parameter(self, set_id: str | None, parameter: list[float]) -> int: def set_parameter(self, set_id: str | None, parameter: list[float]) -> int:
num_parameter = list(filter(lambda g: not isinstance(g, SelectionWidget), self.global_parameter)) num_parameter = list(filter(lambda g: not isinstance(g, SelectionWidget), self.global_parameter))
@ -304,12 +302,12 @@ class ParameterSingleWidget(QtWidgets.QWidget):
self._init_ui() self._init_ui()
self._name = name self.name = name
self.label.setText(convert(name)) self.label.setText(convert(name))
self.label.setToolTip('If this is bold then this parameter is only for this data. ' self.label.setToolTip('If this is bold then this parameter is only for this data. '
'Otherwise, the general parameter is used and displayed') 'Otherwise, the general parameter is used and displayed')
self.value_line.setValidator(QtGui.QDoubleValidator()) # self.value_line.setValidator(QtGui.QDoubleValidator())
self.value_line.textChanged.connect(lambda: self.valueChanged.emit(self.value) if self.value is not None else 0) self.value_line.textChanged.connect(lambda: self.valueChanged.emit(self.value) if self.value is not None else 0)
self.reset_button.clicked.connect(lambda x: self.removeSingleValue.emit()) self.reset_button.clicked.connect(lambda x: self.removeSingleValue.emit())
@ -343,9 +341,10 @@ class ParameterSingleWidget(QtWidgets.QWidget):
@value.setter @value.setter
def value(self, val): def value(self, val):
self.value_line.setText(f'{float(val):.5g}') # self.value_line.setText(f'{float(val):.5g}')
self.value_line.setText(f'{val}')
def show_as_local_parameter(self, is_local): def show_as_local_parameter(self, is_local: bool):
if is_local: if is_local:
self.label.setStyleSheet('font-weight: bold;') self.label.setStyleSheet('font-weight: bold;')
else: else:

View File

@ -128,7 +128,7 @@ class QFunctionWidget(QtWidgets.QWidget, Ui_Form):
self.newFunction.emit(idx, cnt) self.newFunction.emit(idx, cnt)
self.add_function(idx, cnt, op, name, col) self.add_function(idx, cnt, op, name, col, param_names=self.functions[idx].params)
def add_function(self, idx: int, cnt: int, op: int, def add_function(self, idx: int, cnt: int, op: int,
name: str, color: str | tuple[float, float, float] | BaseColor, **kwargs): name: str, color: str | tuple[float, float, float] | BaseColor, **kwargs):
@ -141,6 +141,7 @@ class QFunctionWidget(QtWidgets.QWidget, Ui_Form):
qcolor = QtGui.QColor.fromRgbF(*color) qcolor = QtGui.QColor.fromRgbF(*color)
else: else:
qcolor = QtGui.QColor(color) qcolor = QtGui.QColor(color)
self.functree.add_function(idx, cnt, op, name, qcolor, **kwargs) self.functree.add_function(idx, cnt, op, name, qcolor, **kwargs)
f = self.functions[idx] f = self.functions[idx]

View File

@ -9,6 +9,9 @@ import numpy as np
from pyqtgraph import mkPen from pyqtgraph import mkPen
from nmreval.fit._meta import MultiModel, ModelFactory from nmreval.fit._meta import MultiModel, ModelFactory
from nmreval.fit.data import Data
from nmreval.fit.model import Model
from nmreval.fit.parameter import Parameters
from nmreval.fit.result import FitResult from nmreval.fit.result import FitResult
from .fit_forms import FitTableWidget from .fit_forms import FitTableWidget
@ -116,7 +119,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
# collect parameter names etc. to allow linkage # collect parameter names etc. to allow linkage
self._func_list[self._current_model] = self.functionwidget.get_parameter_list() self._func_list[self._current_model] = self.functionwidget.get_parameter_list()
dialog.set_links(self._func_list) # dialog.set_links(self._func_list)
# show same tab (general parameter/Data parameter) # show same tab (general parameter/Data parameter)
tab_idx = 0 tab_idx = 0
@ -219,57 +222,49 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
def _prepare(self, model: list, function_use: list = None, def _prepare(self, model: list, function_use: list = None,
parameter: dict = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]: parameter: dict = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]:
if parameter is None: if parameter is None:
parameter = {'parameter': {}, 'lb': (), 'ub': (), 'var': [], parameter = {
'glob': {'idx': [], 'p': [], 'var': [], 'lb': [], 'ub': []}, 'data_parameter': {},
'links': [], 'color': []} 'global_parameter': [],
'links': [],
'color': [],
}
for i, f in enumerate(model): for i, f in enumerate(model):
if not f['active']: if not f['active']:
continue continue
try: try:
p, lb, ub, var, glob, links = self.param_widgets[f['cnt']].get_parameter(function_use) p, glob = self.param_widgets[f['cnt']].get_parameter(function_use)
except ValueError as e: except ValueError as e:
_ = QtWidgets.QMessageBox().warning(self, 'Invalid value', str(e), _ = QtWidgets.QMessageBox().warning(self, 'Invalid value', str(e),
QtWidgets.QMessageBox.Ok) QtWidgets.QMessageBox.Ok)
return {}, -1 return {}, -1
p_len = len(parameter['lb']) parameter['color'].append(f['color'])
parameter['global_parameter'].extend(glob)
parameter['lb'] += lb
parameter['ub'] += ub
parameter['var'] += var
parameter['links'] += links
parameter['color'] += [f['color']]
cnt = f['cnt'] cnt = f['cnt']
for p_k, v_k in p.items(): for p_k, v_k in p.items():
if add_idx: if add_idx:
kw_k = {f'{k}_{cnt}': v for k, v in v_k[1].items()} kw_k = {f'{k}_{cnt}': v for k, v in v_k[1].items()}
else: else:
kw_k = v_k[1] kw_k = v_k[1]
if p_k in parameter['parameter']: if p_k in parameter['data_parameter']:
params, kw = parameter['parameter'][p_k] params, kw = parameter['data_parameter'][p_k]
params += v_k[0] params += v_k[0]
kw.update(kw_k) kw.update(kw_k)
else: else:
parameter['parameter'][p_k] = (v_k[0], kw_k) parameter['data_parameter'][p_k] = (v_k[0], kw_k)
for g_k, g_v in glob.items():
if g_k != 'idx':
parameter['glob'][g_k] += g_v
else:
parameter['glob']['idx'] += [idx_i + p_len for idx_i in g_v]
if add_idx: if add_idx:
cnt += 1 cnt += 1
if f['children']: if f['children']:
# recurse for children # recurse for children
child_parameter, cnt = self._prepare(f['children'], parameter=parameter, add_idx=add_idx, cnt=cnt) _, cnt = self._prepare(f['children'], parameter=parameter, add_idx=add_idx, cnt=cnt)
return parameter, cnt return parameter, cnt
@ -280,30 +275,43 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
data = self.data_table.collect_data(default=self.default_combobox.currentData()) data = self.data_table.collect_data(default=self.default_combobox.currentData())
func_dict = {} func_dict = {}
for k, mod in self.models.items(): for model_name, model_parameter in self.models.items():
func, order, param_len = ModelFactory.create_from_list(mod) func, order, param_len = ModelFactory.create_from_list(model_parameter)
if func is None: if func is None:
continue continue
if k in data: func = Model(func)
parameter, _ = self._prepare(mod, function_use=data[k], add_idx=isinstance(func, MultiModel))
if model_name in data:
parameter, _ = self._prepare(model_parameter, function_use=data[model_name], add_idx=isinstance(func, MultiModel))
if parameter is None: if parameter is None:
return return
for (data_parameter, _) in parameter['data_parameter'].values():
for pname, param in zip(func.params, data_parameter):
param.name = pname
if self._complex[model_name] is not None:
for p_k, p_v in parameter['data_parameter'].items():
p_v[1].update({'complex_mode': self._complex[model_name]})
parameter['data_parameter'][p_k] = p_v[0], p_v[1]
for pname, param_value in zip(func.params, parameter['global_parameter']):
if param_value is not None:
param_value.name = pname
func.set_global_parameter(param_value)
parameter['func'] = func parameter['func'] = func
parameter['order'] = order parameter['order'] = order
parameter['len'] = param_len parameter['len'] = param_len
parameter['complex'] = self._complex[k] parameter['complex'] = self._complex[model_name]
if self._complex[k] is not None:
for p_k, p_v in parameter['parameter'].items():
p_v[1].update({'complex_mode': self._complex[k]})
parameter['parameter'][p_k] = p_v[0], p_v[1]
func_dict[k] = parameter func_dict[model_name] = parameter
replaceable = [] replaceable = []
for k, v in func_dict.items(): for model_name, v in func_dict.items():
for i, link_i in enumerate(v['links']): for i, link_i in enumerate(v['links']):
if link_i is None: if link_i is None:
continue continue
@ -334,7 +342,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
QtWidgets.QMessageBox.Ok) QtWidgets.QMessageBox.Ok)
return return
replaceable.append((k, i, rep_model, repl_idx)) replaceable.append((model_name, i, rep_model, repl_idx))
replace_value = None replace_value = None
for p_k in f['parameter'].values(): for p_k in f['parameter'].values():
@ -412,31 +420,37 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
def make_previews(self, x, models_parameters: dict): def make_previews(self, x, models_parameters: dict):
self.preview_lines = [] self.preview_lines = []
# needed to create namespace
param_dict = Parameters()
cnt = 0
for model in models_parameters.values():
f = model['func']
for parameter_list in model['data_parameter'].values():
for i, p_value in enumerate(parameter_list[0]):
p_value.name = f.params[i]
param_dict.add_parameter(f'a{cnt}', p_value)
cnt += 1
for k, model in models_parameters.items(): for k, model in models_parameters.items():
f = model['func'] f = model['func']
is_complex = self._complex[k] is_complex = self._complex[k]
parameters = model['parameter'] parameters = model['data_parameter']
color = model['color'] color = model['color']
seen_parameter = []
for p, kwargs in parameters.values(): for p, kwargs in parameters.values():
if (p, kwargs) in seen_parameter: p_value = [pp.value for pp in p]
# plot only previews with different parameter
continue
seen_parameter.append((p, kwargs))
if is_complex is not None: if is_complex is not None:
y = f.func(x, *p, complex_mode=is_complex, **kwargs) y = f.func(x, *p_value, complex_mode=is_complex, **kwargs)
if np.iscomplexobj(y): if np.iscomplexobj(y):
self.preview_lines.append(PlotItem(x=x, y=y.real, pen=mkPen(width=3))) self.preview_lines.append(PlotItem(x=x, y=y.real, pen=mkPen(width=3)))
self.preview_lines.append(PlotItem(x=x, y=y.imag, pen=mkPen(width=3))) self.preview_lines.append(PlotItem(x=x, y=y.imag, pen=mkPen(width=3)))
else: else:
self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3))) self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3)))
else: else:
y = f.func(x, *p, **kwargs) y = f.func(x, *p_value, **kwargs)
self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3))) self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3)))
if isinstance(f, MultiModel): if isinstance(f, MultiModel):
@ -444,7 +458,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
if is_complex is not None: if is_complex is not None:
sub_kwargs.update({'complex_mode': is_complex}) sub_kwargs.update({'complex_mode': is_complex})
for i, s in enumerate(f.subs(x, *p, **sub_kwargs)): for i, s in enumerate(f.subs(x, *p_value, **sub_kwargs)):
pen_i = mkPen(QtGui.QColor.fromRgbF(*color[i])) pen_i = mkPen(QtGui.QColor.fromRgbF(*color[i]))
if np.iscomplexobj(s): if np.iscomplexobj(s):
self.preview_lines.append(PlotItem(x=x, y=s.real, pen=pen_i)) self.preview_lines.append(PlotItem(x=x, y=s.real, pen=pen_i))
@ -452,15 +466,17 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
else: else:
self.preview_lines.append(PlotItem(x=x, y=s, pen=pen_i)) self.preview_lines.append(PlotItem(x=x, y=s, pen=pen_i))
param_dict.clear()
return self.preview_lines return self.preview_lines
def set_parameter(self, parameter: dict[str, FitResult]): def set_parameter(self, parameter: dict[str, FitResult]):
# which data uses which model # which data uses which model
data = self.data_table.collect_data(default=self.default_combobox.currentData()) data = self.data_table.collect_data(default=self.default_combobox.currentData())
for fitted_model, fitted_data in data.items():
glob_fit_parameter = [] glob_fit_parameter = []
for fitted_model, fitted_data in data.items():
for fit_id, fit_curve in parameter.items(): for fit_id, fit_curve in parameter.items():
if fit_id in fitted_data: if fit_id in fitted_data:
fit_parameter = list(fit_curve.parameter.values()) fit_parameter = list(fit_curve.parameter.values())

View File

@ -138,9 +138,7 @@ class DrawingsWidget(QtWidgets.QWidget, Ui_Form):
graph_id = self.graph_comboBox.currentData() graph_id = self.graph_comboBox.currentData()
current_lines = self.lines[graph_id] current_lines = self.lines[graph_id]
print(remove_rows)
for i in reversed(remove_rows): for i in reversed(remove_rows):
print(i)
self.tableWidget.removeRow(i) self.tableWidget.removeRow(i)
self.line_deleted.emit(current_lines[i], graph_id) self.line_deleted.emit(current_lines[i], graph_id)

View File

@ -27,7 +27,6 @@ class MdiAreaTile(QtWidgets.QMdiArea):
pos = QtCore.QPoint(0, 0) pos = QtCore.QPoint(0, 0)
for win in window_list: for win in window_list:
print(win.minimumSize())
win.setGeometry(rect) win.setGeometry(rect)
win.move(pos) win.move(pos)

View File

@ -1,110 +0,0 @@
import os.path
import json
import urllib.request
import webbrowser
import random
from ..Qt import QtGui, QtCore, QtWidgets
from .._py.pokemon import Ui_Dialog
random.seed()
class QPokemon(QtWidgets.QDialog, Ui_Dialog):
def __init__(self, number=None, parent=None):
super().__init__(parent=parent)
self.setupUi(self)
self._js = json.load(open(os.path.join(path_to_module, 'utils', 'pokemon.json'), 'r'), encoding='UTF-8')
self._id = 0
if number is not None and number in range(1, len(self._js)+1):
poke_nr = f'{number:03d}'
self._id = number
else:
poke_nr = f'{random.randint(1, len(self._js)):03d}'
self._id = int(poke_nr)
self._pokemon = None
self.show_pokemon(poke_nr)
self.label_15.linkActivated.connect(lambda x: webbrowser.open(x))
self.buttonBox.clicked.connect(self.randomize)
self.next_button.clicked.connect(self.show_next)
self.prev_button.clicked.connect(self.show_prev)
def show_pokemon(self, nr):
self._pokemon = self._js[nr]
self.setWindowTitle('Pokémon: ' + self._pokemon['Deutsch'])
for i in range(self.tabWidget.count(), -1, -1):
print('i', self.tabWidget.count(), i)
try:
self.tabWidget.widget(i).deleteLater()
except AttributeError:
pass
for n, img in self._pokemon['Bilder']:
w = QtWidgets.QWidget()
vl = QtWidgets.QVBoxLayout()
l = QtWidgets.QLabel(self)
l.setAlignment(QtCore.Qt.AlignHCenter)
pixmap = QtGui.QPixmap()
try:
pixmap.loadFromData(urllib.request.urlopen(img, timeout=0.5).read())
except IOError:
l.setText(n)
else:
sc_pixmap = pixmap.scaled(256, 256, QtCore.Qt.KeepAspectRatio)
l.setPixmap(sc_pixmap)
vl.addWidget(l)
w.setLayout(vl)
self.tabWidget.addTab(w, n)
if len(self._pokemon['Bilder']) <= 1:
self.tabWidget.tabBar().setVisible(False)
else:
self.tabWidget.tabBar().setVisible(True)
self.tabWidget.adjustSize()
self.name.clear()
keys = ['National-Dex', 'Kategorie', 'Typ', 'Größe', 'Gewicht', 'Farbe', 'Link']
label_list = [self.pokedex_nr, self.category, self.poketype, self.weight, self.height, self.color, self.info]
for (k, label) in zip(keys, label_list):
v = self._pokemon[k]
if isinstance(v, list):
v = os.path.join('', *v)
if k == 'Link':
v = '<a href={}>{}</a>'.format(v, v)
label.setText(v)
for k in ['Deutsch', 'Japanisch', 'Englisch', 'Französisch']:
v = self._pokemon[k]
self.name.addItem(k + ': ' + v)
self.adjustSize()
def randomize(self, idd):
if idd.text() == 'Retry':
new_number = f'{random.randint(1, len(self._js)):03d}'
self._id = int(new_number)
self.show_pokemon(new_number)
else:
self.close()
def show_next(self):
new_number = self._id + 1
if new_number > len(self._js):
new_number = 1
self._id = new_number
self.show_pokemon(f'{new_number:03d}')
def show_prev(self):
new_number = self._id - 1
if new_number == 0:
new_number = len(self._js)
self._id = new_number
self.show_pokemon(f'{new_number:03d}')

View File

@ -441,7 +441,7 @@ class UpperManagement(QtCore.QObject):
# all-encompassing error catch # all-encompassing error catch
try: try:
for model_id, model_p in parameter.items(): for model_id, model_p in parameter.items():
m = Model(model_p['func']) m = model_p['func']
models[model_id] = m models[model_id] = m
m_complex = model_p['complex'] m_complex = model_p['complex']
@ -450,13 +450,16 @@ class UpperManagement(QtCore.QObject):
# iterate over order of set id in active order and access parameter inside loop # iterate over order of set id in active order and access parameter inside loop
# instead of directly looping # instead of directly looping
try: try:
list_ids = list(model_p['parameter'].keys()) list_ids = list(model_p['data_parameter'].keys())
set_order = [self.active_id.index(i) for i in list_ids] set_order = [self.active_id.index(i) for i in list_ids]
except ValueError as e: except ValueError as e:
raise Exception('Getting order failed') from e raise Exception('Getting order failed') from e
for pos in set_order: for pos in set_order:
set_id = list_ids[pos] set_id = list_ids[pos]
data_i = self.data[set_id]
set_params = model_p['data_parameter'][set_id]
try: try:
data_i = self.data[set_id] data_i = self.data[set_id]
except KeyError as e: except KeyError as e:
@ -499,18 +502,12 @@ class UpperManagement(QtCore.QObject):
d.set_model(m) d.set_model(m)
try: try:
d.set_parameter(set_params[0], var=model_p['var'], d.set_parameter(set_params[0], fun_kwargs=set_params[1])
lb=model_p['lb'], ub=model_p['ub'],
fun_kwargs=set_params[1])
except Exception as e: except Exception as e:
raise Exception('Setting parameter failed') from e raise Exception('Setting parameter failed') from e
self.fitter.add_data(d) self.fitter.add_data(d)
model_globs = model_p['glob']
if model_globs:
m.set_global_parameter(**model_p['glob'])
for links_i in links: for links_i in links:
self.fitter.set_link_parameter((models[links_i[0]], links_i[1]), self.fitter.set_link_parameter((models[links_i[0]], links_i[1]),
(models[links_i[2]], links_i[3])) (models[links_i[2]], links_i[3]))
@ -1170,7 +1167,6 @@ class UpperManagement(QtCore.QObject):
@QtCore.pyqtSlot(dict) @QtCore.pyqtSlot(dict)
def calc_relaxation(self, opts: dict): def calc_relaxation(self, opts: dict):
params = opts['pts'] params = opts['pts']
if len(params) == 4: if len(params) == 4:
if params[3]: if params[3]:

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import numpy as np import numpy as np
from .model import Model from .model import Model
from .parameter import Parameters from .parameter import Parameters, Parameter
class Data(object): class Data(object):
@ -16,7 +18,7 @@ class Data(object):
self.model = None self.model = None
self.minimizer = None self.minimizer = None
self.parameter = Parameters() self.parameter = Parameters()
self.para_keys = None self.para_keys: list = []
self.fun_kwargs = {} self.fun_kwargs = {}
def __len__(self): def __len__(self):
@ -68,12 +70,19 @@ class Data(object):
def get_model(self): def get_model(self):
return self.model return self.model
def set_parameter(self, parameter, var=None, ub=None, lb=None, def set_parameter(self,
default_bounds=False, fun_kwargs=None): values: list[float | Parameter],
*,
var: list[bool] = None,
ub: list[float] = None,
lb: list[float] = None,
default_bounds: bool = False,
fun_kwargs: dict = None
):
""" """
Creates parameter for this data. Creates parameter for this data.
If no Model is available, it falls back to the model If no Model is available, it falls back to the model
:param parameter: list of parameters :param values: list of parameters
:param var: list of boolean or boolean; False fixes parameter at given list index. :param var: list of boolean or boolean; False fixes parameter at given list index.
Single value is broadcast to all parameter Single value is broadcast to all parameter
:param ub: list of upper boundaries or float; Single value is broadcast to all parameter. :param ub: list of upper boundaries or float; Single value is broadcast to all parameter.
@ -87,23 +96,46 @@ class Data(object):
model = self.model model = self.model
if model is None: if model is None:
# Data has no unique # Data has no unique
if self.minimizer is None: if self.minimizer is not None:
model = None
else:
model = self.minimizer.fit_model model = self.minimizer.fit_model
self.fun_kwargs.update(model.fun_kwargs)
if model is None: if model is None:
raise ValueError('No model found, please set model before parameters') raise ValueError('No model found, please set model before parameters')
if default_bounds: if len(values) != len(model.params):
raise ValueError('Number of given parameter does not match number of model parameters')
is_parameter = [isinstance(v, Parameter) for v in values]
if all(is_parameter):
for p_i in values:
key = f"p{next(Parameters.parameter_counter)}"
self.parameter.add_parameter(key, p_i)
elif any(is_parameter):
raise ValueError('list of parameter are not all float of Parameter')
else:
if var is None:
var = [True] * len(values)
if lb is None: if lb is None:
if default_bounds:
lb = model.lb lb = model.lb
else:
lb = [None] * len(values)
if ub is None: if ub is None:
if default_bounds:
ub = model.ub ub = model.ub
else:
ub = [None] * len(values)
self.para_keys = self.parameter.add_parameter(parameter, var=var, lb=lb, ub=ub) arg_names = ['name', 'value', 'var', 'lb', 'ub']
for parameter_arg in zip(model.params, values, var, lb, ub):
self.parameter.add(**{arg_name: arg_value for arg_name, arg_value in zip(arg_names, parameter_arg)})
self.para_keys = list(self.parameter.keys())
self.fun_kwargs.update(model.fun_kwargs)
if fun_kwargs is not None: if fun_kwargs is not None:
self.fun_kwargs.update(fun_kwargs) self.fun_kwargs.update(fun_kwargs)
@ -123,6 +155,18 @@ class Data(object):
else: else:
return [p.value for p in self.minimizer.parameters[self.parameter]] return [p.value for p in self.minimizer.parameters[self.parameter]]
def replace_parameter(self, key: str, parameter: Parameter) -> None:
tobereplaced = None
for k, v in self.parameter.items():
if v.name == parameter.name:
tobereplaced = k
break
if tobereplaced is None:
raise KeyError(f'Global parameter {key} not found in list of parameters')
self.para_keys[self.para_keys.index(tobereplaced)] = key
self.parameter.replace_parameter(tobereplaced, key, parameter)
def cost(self, p): def cost(self, p):
""" """
Cost function :math:`y-f(p, x)` Cost function :math:`y-f(p, x)`

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import warnings import warnings
from itertools import product from itertools import product
@ -21,13 +23,70 @@ class FitAbortException(Exception):
pass pass
# COST FUNCTIONS: f(x) - y (least_square, minimize), and f(x) (ODR)
def _cost_scipy_glob(p: list[float], data: list[Data], varpars: list[str], used_pars: list[list[str]]):
# replace values
for keys, values in zip(varpars, p):
for data_i in data:
if keys in data_i.parameter.keys():
# TODO move this to scaled_value setter
data_i.parameter[keys].scaled_value = values
data_i.parameter[keys].namespace[keys] = data_i.parameter[keys].value
r = []
# unpack parameter and calculate y values and concatenate all
for values, p_idx in zip(data, used_pars):
actual_parameters = [values.parameter[keys].value for keys in p_idx]
r = np.r_[r, values.cost(actual_parameters)]
return r
def _cost_scipy(p, data, varpars, used_pars):
for keys, values in zip(varpars, p):
data.parameter[keys].scaled_value = values
data.parameter[keys].namespace[keys] = data.parameter[keys].value
actual_parameters = [data.parameter[keys].value for keys in used_pars]
return data.cost(actual_parameters)
def _cost_odr(p: list[float], data: Data, varpars: list[str], used_pars: list[str], fitmode: int=0):
for keys, values in zip(varpars, p):
data.parameter[keys].scaled_value = values
data.parameter[keys].namespace[keys] = data.parameter[keys].value
actual_parameters = [data.parameter[keys].value for keys in used_pars]
return data.func(actual_parameters, data.x)
def _cost_odr_glob(p: list[float], data: list[Data], var_pars: list[str], used_pars: list[str]):
# replace values
for data_i in data:
_update_parameter(data_i, var_pars, p)
r = []
# unpack parameter and calculate y values and concatenate all
for values, p_idx in zip(data, used_pars):
actual_parameters = [values.parameter[keys].value for keys in p_idx]
r = np.r_[r, values.func(actual_parameters, values.x)]
return r
def _update_parameter(data: Data, varied_keys: list[str], parameter: list[float]):
for keys, values in zip(varied_keys, parameter):
if keys in data.parameter.keys():
data.parameter[keys].scaled_value = values
data.parameter[keys].namespace[keys] = data.parameter[keys].value
class FitRoutine(object): class FitRoutine(object):
def __init__(self, mode='lsq'): def __init__(self, mode='lsq'):
self.fitmethod = mode self.fitmethod = mode
self.data = [] self.data = []
self.fit_model = None self.fit_model = None
self._no_own_model = [] self._no_own_model = []
self.parameter = Parameters()
self.result = [] self.result = []
self.linked = [] self.linked = []
self._abort = False self._abort = False
@ -81,29 +140,27 @@ class FitRoutine(object):
return self.fit_model return self.fit_model
def set_link_parameter(self, parameter: tuple, replacement: tuple): def set_link_parameter(self, dismissed_param: tuple[Model | Data, str], replacement: tuple[Model, str]):
if isinstance(replacement[0], Model): if isinstance(replacement[0], Model):
if replacement[1] not in replacement[0].global_parameter: if replacement[1] not in replacement[0].parameter:
raise KeyError(f'Parameter at pos {replacement[1]} of ' raise KeyError(f'Parameter {replacement[1]} of '
f'model {str(replacement[0])} is not global') f'model {replacement[0]} is not global')
if isinstance(parameter[0], Model): if isinstance(dismissed_param[0], Model):
warnings.warn(f'Replaced parameter at pos {parameter[1]} in {str(parameter[0])} ' warnings.warn(f'Replaced parameter {dismissed_param[1]} in {dismissed_param[0]} '
f'becomes global with linkage.') f'becomes global with linkage.')
self.linked.append((*parameter, *replacement)) self.linked.append((*dismissed_param, *replacement))
def prepare_links(self): def prepare_links(self):
self._no_own_model = [] self._no_own_model = []
self.parameter = Parameters()
_found_models = {} _found_models = {}
linked_sender = {} linked_sender = {}
for v in self.data: for v in self.data:
linked_sender[v] = set() linked_sender[v] = set()
self.parameter.update(v.parameter.copy())
# set temporaray model # set temporary model
if v.model is None: if v.model is None:
v.model = self.fit_model v.model = self.fit_model
self._no_own_model.append(v) self._no_own_model.append(v)
@ -111,8 +168,6 @@ class FitRoutine(object):
# register model # register model
if v.model not in _found_models: if v.model not in _found_models:
_found_models[v.model] = [] _found_models[v.model] = []
m_param = v.model.parameter.copy()
self.parameter.update(m_param)
_found_models[v.model].append(v) _found_models[v.model].append(v)
@ -120,24 +175,21 @@ class FitRoutine(object):
linked_sender[v.model] = set() linked_sender[v.model] = set()
linked_parameter = {} linked_parameter = {}
for par, par_parm, repl, repl_par in self.linked: for dismiss_model, dismiss_param, replace_model, replace_param in self.linked:
if isinstance(par, Data): linked_sender[replace_model].add(dismiss_model)
if isinstance(repl, Data): linked_sender[replace_model].add(replace_model)
linked_parameter[par.para_keys[par_parm]] = repl.para_keys[repl_par]
else:
linked_parameter[par.para_keys[par_parm]] = repl.global_parameter[repl_par]
else: replace_key = replace_model.parameter.get_key(replace_param)
if isinstance(repl, Data): dismiss_key = dismiss_model.parameter.get_key(dismiss_param)
par.global_parameter[par_parm] = repl.para_keys[repl_par]
else:
par.global_parameter[par_parm] = repl.global_parameter[repl_par]
linked_sender[repl].add(par) if isinstance(replace_model, Data):
linked_sender[par].add(repl) linked_parameter[dismiss_key] = replace_key
else:
p = dismiss_model.set_global_parameter(dismiss_param, replace_key)
p._expr_disp = replace_param
for mm, m_data in _found_models.items(): for mm, m_data in _found_models.items():
if mm.global_parameter: if mm.parameter:
for dd in m_data: for dd in m_data:
linked_sender[mm].add(dd) linked_sender[mm].add(dd)
linked_sender[dd].add(mm) linked_sender[dd].add(mm)
@ -171,13 +223,11 @@ class FitRoutine(object):
def run(self, mode: str = None): def run(self, mode: str = None):
self._abort = False self._abort = False
self.parameter = Parameters()
if mode is None: if mode is None:
mode = self.fitmethod mode = self.fitmethod
fit_groups, linked_parameter = self.prepare_links() fit_groups, linked_parameter = self.prepare_links()
for data_groups in fit_groups: for data_groups in fit_groups:
if len(data_groups) == 1 and not self.linked: if len(data_groups) == 1 and not self.linked:
data = data_groups[0] data = data_groups[0]
@ -208,8 +258,21 @@ class FitRoutine(object):
self.unprep_run() self.unprep_run()
for r in self.result:
r.pprint()
return self.result return self.result
def make_preview(self, x: np.ndarray) -> list[np.ndarray]:
y_pred = []
fit_groups, linked_parameter = self.prepare_links()
for data_groups in fit_groups:
data = data_groups[0]
actual_parameters = [p.value for p in data.parameter.values()]
y_pred.append(data.func(actual_parameters, x))
return y_pred
def _prep_data(self, data): def _prep_data(self, data):
if data.get_model() is None: if data.get_model() is None:
data._model = self.fit_model data._model = self.fit_model
@ -237,22 +300,16 @@ class FitRoutine(object):
var = [] var = []
data_pars = [] data_pars = []
# loopyloop over data that belong to one fit (linked or global) # loopy-loop over data that belong to one fit (linked or global)
for data in data_group: for data in data_group:
actual_pars = []
for i, (p_k, v_k) in enumerate(data.parameter.items()):
p_k_used = p_k
v_k_used = v_k
# is parameter replaced by global parameter? # is parameter replaced by global parameter?
if i in data.model.global_parameter: for k, v in data.model.parameter.items():
p_k_used = data.model.global_parameter[i] data.replace_parameter(k, v)
v_k_used = self.parameter[p_k_used]
# links trump global parameter actual_pars = []
if p_k_used in linked: for i, p_k in enumerate(data.para_keys):
p_k_used = linked[p_k_used] p_k_used = p_k
v_k_used = self.parameter[p_k_used] v_k_used = data.parameter[p_k]
actual_pars.append(p_k_used) actual_pars.append(p_k_used)
# parameter is variable and was not found before as shared parameter # parameter is variable and was not found before as shared parameter
@ -271,48 +328,7 @@ class FitRoutine(object):
d._model = None d._model = None
self._no_own_model = [] self._no_own_model = []
Parameters.reset()
# COST FUNCTIONS: f(x) - y (least_square, minimize), and f(x) (ODR)
def __cost_scipy(self, p, data, varpars, used_pars):
for keys, values in zip(varpars, p):
self.parameter[keys].scaled_value = values
actual_parameters = [self.parameter[keys].value for keys in used_pars]
return data.cost(actual_parameters)
def __cost_odr(self, p, data, varpars, used_pars):
for keys, values in zip(varpars, p):
self.parameter[keys].scaled_value = values
actual_parameters = [self.parameter[keys].value for keys in used_pars]
return data.func(actual_parameters, data.x)
def __cost_scipy_glob(self, p, data, varpars, used_pars):
# replace values
for keys, values in zip(varpars, p):
self.parameter[keys].scaled_value = values
r = []
# unpack parameter and calculate y values and concatenate all
for values, p_idx in zip(data, used_pars):
actual_parameters = [self.parameter[keys].value for keys in p_idx]
r = np.r_[r, values.cost(actual_parameters)]
return r
def __cost_odr_glob(self, p, data, varpars, used_pars):
# replace values
for keys, values in zip(varpars, p):
self.parameter[keys].scaled_value = values
r = []
# unpack parameter and calculate y values and concatenate all
for values, p_idx in zip(data, used_pars):
actual_parameters = [self.parameter[keys].value for keys in p_idx]
r = np.r_[r, values.func(actual_parameters, values.x)]
return r
def _least_squares_single(self, data, p0, lb, ub, var): def _least_squares_single(self, data, p0, lb, ub, var):
self.step = 0 self.step = 0
@ -322,7 +338,7 @@ class FitRoutine(object):
if self._abort: if self._abort:
raise FitAbortException(f'Fit aborted by user') raise FitAbortException(f'Fit aborted by user')
return self.__cost_scipy(p, data, var, data.para_keys) return _cost_scipy(p, data, var, data.para_keys)
with np.errstate(all='ignore'): with np.errstate(all='ignore'):
res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0)) res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0))
@ -336,7 +352,7 @@ class FitRoutine(object):
self.step += 1 self.step += 1
if self._abort: if self._abort:
raise FitAbortException(f'Fit aborted by user') raise FitAbortException(f'Fit aborted by user')
return self.__cost_scipy_glob(p, data, var, data_pars) return _cost_scipy_glob(p, data, var, data_pars)
with np.errstate(all='ignore'): with np.errstate(all='ignore'):
res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0)) res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0))
@ -351,7 +367,7 @@ class FitRoutine(object):
self.step += 1 self.step += 1
if self._abort: if self._abort:
raise FitAbortException(f'Fit aborted by user') raise FitAbortException(f'Fit aborted by user')
return (self.__cost_scipy(p, data, var, data.para_keys)**2).sum() return (_cost_scipy(p, data, var, data.para_keys) ** 2).sum()
with np.errstate(all='ignore'): with np.errstate(all='ignore'):
res = optimize.minimize(cost, p0, bounds=[(b1, b2) for (b1, b2) in zip(lb, ub)], res = optimize.minimize(cost, p0, bounds=[(b1, b2) for (b1, b2) in zip(lb, ub)],
@ -364,7 +380,7 @@ class FitRoutine(object):
self.step += 1 self.step += 1
if self._abort: if self._abort:
raise FitAbortException(f'Fit aborted by user') raise FitAbortException(f'Fit aborted by user')
return (self.__cost_scipy_glob(p, data, var, data_pars)**2).sum() return (_cost_scipy_glob(p, data, var, data_pars) ** 2).sum()
with np.errstate(all='ignore'): with np.errstate(all='ignore'):
res = optimize.minimize(cost, p0, bounds=[(b1, b2) for (b1, b2) in zip(lb, ub)], res = optimize.minimize(cost, p0, bounds=[(b1, b2) for (b1, b2) in zip(lb, ub)],
@ -380,13 +396,18 @@ class FitRoutine(object):
self.step += 1 self.step += 1
if self._abort: if self._abort:
raise FitAbortException(f'Fit aborted by user') raise FitAbortException(f'Fit aborted by user')
return self.__cost_odr(p, data, var_pars, data.para_keys) return _cost_odr(p, data, var_pars, data.para_keys)
odr_model = odr.Model(func) odr_model = odr.Model(func)
corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0)
self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)),
err=res.sd_beta, corr=corr, partial_corr=partial_corr)
def _odr_fit(self, odr_data, odr_model, p0):
o = odr.ODR(odr_data, odr_model, beta0=p0) o = odr.ODR(odr_data, odr_model, beta0=p0)
res = o.run() res = o.run()
corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var
try: try:
corr_inv = np.linalg.inv(corr) corr_inv = np.linalg.inv(corr)
@ -395,16 +416,14 @@ class FitRoutine(object):
partial_corr[np.diag_indices_from(partial_corr)] = 1. partial_corr[np.diag_indices_from(partial_corr)] = 1.
except np.linalg.LinAlgError: except np.linalg.LinAlgError:
partial_corr = corr partial_corr = corr
return corr, partial_corr, res
self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)),
err=res.sd_beta, corr=corr, partial_corr=partial_corr)
def _odr_global(self, data, p0, var, data_pars): def _odr_global(self, data, p0, var, data_pars):
def func(p, _): def func(p, _):
self.step += 1 self.step += 1
if self._abort: if self._abort:
raise FitAbortException(f'Fit aborted by user') raise FitAbortException(f'Fit aborted by user')
return self.__cost_odr_glob(p, data, var, data_pars) return _cost_odr_glob(p, data, var, data_pars)
x = [] x = []
y = [] y = []
@ -415,17 +434,7 @@ class FitRoutine(object):
odr_data = odr.Data(x, y) odr_data = odr.Data(x, y)
odr_model = odr.Model(func) odr_model = odr.Model(func)
o = odr.ODR(odr_data, odr_model, beta0=p0, ifixb=var) corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0)
res = o.run()
corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var
try:
corr_inv = np.linalg.inv(corr)
corr_inv_diag = np.diag(np.sqrt(1 / np.diag(corr_inv)))
partial_corr = -1. * np.dot(np.dot(corr_inv_diag, corr_inv), corr_inv_diag) # Partial correlation matrix
partial_corr[np.diag_indices_from(partial_corr)] = 1.
except np.linalg.LinAlgError:
partial_corr = corr
for v, var_pars_k in zip(data, data_pars): for v, var_pars_k in zip(data, data_pars):
self.make_results(v, res.beta, var, var_pars_k, (sum(len(d) for d in data), len(p0)), self.make_results(v, res.beta, var, var_pars_k, (sum(len(d) for d in data), len(p0)),
@ -439,15 +448,17 @@ class FitRoutine(object):
# update parameter values # update parameter values
for keys, p_value, err_value in zip(var_pars, p, err): for keys, p_value, err_value in zip(var_pars, p, err):
self.parameter[keys].scaled_value = p_value if keys in data.parameter.keys():
self.parameter[keys].scaled_error = err_value data.parameter[keys].scaled_value = p_value
data.parameter[keys].scaled_error = err_value
data.parameter[keys].namespace[keys] = data.parameter[keys].value
combinations = list(product(var_pars, var_pars)) combinations = list(product(var_pars, var_pars))
actual_parameters = [] actual_parameters = []
corr_idx = [] corr_idx = []
for i, p_i in enumerate(used_pars): for i, p_i in enumerate(used_pars):
actual_parameters.append(self.parameter[p_i]) actual_parameters.append(data.parameter[p_i])
for j, p_j in enumerate(used_pars): for j, p_j in enumerate(used_pars):
try: try:
# find the position of the parameter combinations # find the position of the parameter combinations
@ -508,3 +519,4 @@ class FitRoutine(object):
partial_corr = corr partial_corr = corr
return _err, corr, partial_corr return _err, corr, partial_corr

View File

@ -6,7 +6,7 @@ from typing import Sized
from numpy import inf from numpy import inf
from ._meta import MultiModel from ._meta import MultiModel
from .parameter import Parameters from .parameter import Parameters, Parameter
class Model(object): class Model(object):
@ -25,7 +25,6 @@ class Model(object):
self.ub = [i if i is not None else inf for i in self.ub] self.ub = [i if i is not None else inf for i in self.ub]
self.parameter = Parameters() self.parameter = Parameters()
self.global_parameter = {}
self.is_complex = None self.is_complex = None
self._complex_part = False self._complex_part = False
@ -80,23 +79,33 @@ class Model(object):
self.fun_kwargs = {k: v.default for k, v in inspect.signature(model.func).parameters.items() self.fun_kwargs = {k: v.default for k, v in inspect.signature(model.func).parameters.items()
if v.default is not inspect.Parameter.empty} if v.default is not inspect.Parameter.empty}
def set_global_parameter(self, idx, p, var=None, lb=None, ub=None, default_bounds=False): def set_global_parameter(self,
if idx is None: key: str | Parameter,
self.parameter = Parameters() value: float | str = None,
self.global_parameter = {} *,
return var: bool = None,
lb: float = None,
ub: float = None,
default_bounds: bool = False,
) -> Parameter:
if isinstance(key, Parameter):
p = key
key = f'p{next(Parameters.parameter_counter)}'
self.parameter.add_parameter(key, p)
else:
idx = [self.params.index(key)]
if default_bounds: if default_bounds:
if lb is None: if lb is None:
lb = [self.lb[i] for i in idx] lb = [self.lb[i] for i in idx]
if ub is None: if ub is None:
ub = [self.lb[i] for i in idx] ub = [self.lb[i] for i in idx]
gp = self.parameter.add_parameter(p, var=var, lb=lb, ub=ub) p = self.parameter.add(key, value, var=var, lb=lb, ub=ub)
for k, v in zip(idx, gp): p.is_global = True
self.global_parameter[k] = v
return gp return p
@staticmethod @staticmethod
def _prep(param_len, val): def _prep(param_len, val):

View File

@ -1,94 +1,134 @@
from __future__ import annotations from __future__ import annotations
from numbers import Number import re
from itertools import count from itertools import count
from io import StringIO
import numpy as np import numpy as np
class Parameters(dict): class Parameters(dict):
count = count() parameter_counter = count()
# is one global namespace a good idea?
namespace: dict = {}
def __str__(self): def __init__(self):
return 'Parameters:\n' + '\n'.join([str(k)+': '+str(v) for k, v in self.items()]) super().__init__()
self._mapping: dict = {}
def __getitem__(self, item): def __str__(self) -> str:
if isinstance(item, (list, tuple, np.ndarray)): return 'Parameters:\n' + '\n'.join([f'{k}: {v}' for k, v in self.items()])
values = []
for item_i in item: def __getitem__(self, item) -> Parameter:
values.append(super().__getitem__(item_i)) if item in self._mapping:
return values return super().__getitem__(self._mapping[item])
else: else:
return super().__getitem__(item) return super().__getitem__(item)
def __setitem__(self, key, value):
self.add_parameter(key, value)
def __contains__(self, item):
for v in self.values():
if item == v.name:
return True
return False
def add(self,
name: str,
value: str | float | int = None,
*,
var: bool = True,
lb: float = -np.inf, ub: float = np.inf) -> Parameter:
par = Parameter(name=name, value=value, var=var, lb=lb, ub=ub)
key = f'p{next(Parameters.parameter_counter)}'
self.add_parameter(key, par)
return par
def add_parameter(self, key: str, parameter: Parameter):
self._mapping[parameter.name] = key
super().__setitem__(key, parameter)
parameter.eval_allowed = False
self.namespace[key] = parameter.value
parameter.namespace = self.namespace
parameter.eval_allowed = True
self.update_namespace()
def replace_parameter(self, key_out: str, key_in: str, parameter: Parameter):
self.add_parameter(key_in, parameter)
for k, v in self._mapping.items():
if v == key_out:
self._mapping[k] = key_in
break
if key_out in self.namespace:
del self.namespace[key_out]
def fix(self):
for v in self.keys():
v._value = v.value
v.namespace = {}
@staticmethod @staticmethod
def _prep_bounds(val, p_len: int) -> list: def reset():
# helper function to ensure that bounds and variable are of parameter shape Parameters.namespace = {}
if isinstance(val, (Number, bool)) or val is None:
return [val] * p_len
elif len(val) == p_len: def get_key(self, name: str) -> str | None:
return val
elif len(val) == 1:
return [val[0]] * p_len
else:
raise ValueError('Input {} has wrong dimensions'.format(val))
def add_parameter(self, param, var=None, lb=None, ub=None):
if isinstance(param, Number):
param = [param]
p_len = len(param)
# make list if only single value is given
var = self._prep_bounds(var, p_len)
lb = self._prep_bounds(lb, p_len)
ub = self._prep_bounds(ub, p_len)
new_keys = []
for i in range(p_len):
new_idx = next(self.count)
new_keys.append(new_idx)
self[new_idx] = Parameter(param[i], var=var[i], lb=lb[i], ub=ub[i])
return new_keys
def copy(self):
p = Parameters()
for k, v in self.items(): for k, v in self.items():
p[k] = Parameter(v.value, var=v.var, lb=v.lb, ub=v.ub) if name == v.name:
return k
if len(p) == 0: return
return p
max_k = max(p.keys())
c = next(p.count)
while c < max_k:
c = next(p.count)
return p
def get_state(self): def get_state(self):
return {k: v.get_state() for k, v in self.items()} return {k: v.get_state() for k, v in self.items()}
def update_namespace(self):
for p in self.values():
try:
p.value
except NameError:
expression = p._expr_disp
for n, k in self._mapping.items():
expression, num_replaced = re.subn(re.escape(n), k, expression)
p._expr = expression
class Parameter: class Parameter:
""" """
Container for one parameter Container for one parameter
""" """
__slots__ = ['name', 'value', 'error', 'init_val', 'var', 'lb', 'ub', 'scale', 'function']
def __init__(self, value: float, var: bool = True, lb: float = -np.inf, ub: float = np.inf): # TODO Parameter should know its own key
self.lb = lb if lb is not None else -np.inf def __init__(self, name: str, value: float | str, var: bool = True, lb: float = -np.inf, ub: float = np.inf):
self.ub = ub if ub is not None else np.inf self._value: float | None = None
self.var: bool = bool(var) if var is not None else True
self.error: None | float = None if self.var is False else 0.0
self.name: str = name
self.function: str = ""
if self.lb <= value <= self.ub: self.lb: None | float = lb if lb is not None else -np.inf
self.value = value self.ub: float | None = ub if ub is not None else np.inf
self.namespace: dict = {}
self.eval_allowed: bool = True
self._expr: None | str = None
self._expr_disp: None | str = None
self.is_global = False
if isinstance(value, str):
self._expr_disp = value
self._expr = value
self.var = False
else:
if self.lb <= value <= self.ub:
self._value = value
else: else:
print(value, self.lb, self.ub)
raise ValueError('Value of parameter is outside bounds') raise ValueError('Value of parameter is outside bounds')
self.init_val = value self.init_val = value
@ -100,25 +140,31 @@ class Parameter:
if self.scale == 0: if self.scale == 0:
self.scale = 1. self.scale = 1.
self.var = bool(var) if var is not None else True def __str__(self) -> str:
self.error = None if self.var is False else 0.0 start = StringIO()
self.name = ''
self.function = ''
def __str__(self):
start = ''
if self.name: if self.name:
if self.function: if self.function:
start = f'{self.name} ({self.function}): ' start.write(f"{self.name} ({self.function})")
else: else:
start = self.name + ': ' start.write(self.name)
if self.is_global:
start.write("*")
start.write(": ")
if self.var: if self.var:
return start + f'{self.value:.4g} +/- {self.error:.4g}, init={self.init_val}' start.write(f"{self.value:.4g} +/- {self.error:.4g}, init={self.init_val}")
else: else:
return start + f'{self.value:} (fixed)' start.write(f"{self.value:.4g}")
if self._expr is None:
start.write(" (fixed)")
else:
start.write(f" (calc: {self._expr_disp})")
def __add__(self, other: Parameter | float) -> float: return start.getvalue()
def __add__(self, other: Parameter | float | int) -> float:
if isinstance(other, (float, int)): if isinstance(other, (float, int)):
return self.value + other return self.value + other
elif isinstance(other, Parameter): elif isinstance(other, Parameter):
@ -128,30 +174,39 @@ class Parameter:
return self.__add__(other) return self.__add__(other)
@property @property
def scaled_value(self): def scaled_value(self) -> float:
return self.value / self.scale return self.value / self.scale
@scaled_value.setter @scaled_value.setter
def scaled_value(self, value): def scaled_value(self, value: float) -> None:
self.value = value * self.scale self._value = value * self.scale
@property @property
def scaled_error(self): def value(self) -> float | None:
if self.error is None: if self._value is not None:
return self.error return self._value
else:
if self._expr is not None and self.eval_allowed:
return eval(self._expr, {}, self.namespace)
return
@property
def scaled_error(self) -> None | float:
if self.error is not None:
return self.error / self.scale return self.error / self.scale
return
@scaled_error.setter @scaled_error.setter
def scaled_error(self, value): def scaled_error(self, value) -> None:
self.error = value * self.scale self.error = value * self.scale
def get_state(self): def get_state(self) -> dict:
return {slot: getattr(self, slot) for slot in self.__slots__} return {slot: getattr(self, slot) for slot in self.__slots__}
@staticmethod @staticmethod
def set_state(state: dict): def set_state(state: dict) -> Parameter:
par = Parameter(state.pop('value')) par = Parameter(state.pop('value'))
for k, v in state.items(): for k, v in state.items():
setattr(par, k, v) setattr(par, k, v)
@ -159,9 +214,28 @@ class Parameter:
return par return par
@property @property
def full_name(self): def full_name(self) -> str:
name = self.name name = self.name
if self.function: if self.function:
name += ' (' + self.function + ')' name += f" ({self.function})"
return name return name
def copy(self) -> Parameter:
if self._expr:
val = self._expr_disp
else:
val = self._value
para_copy = Parameter(name=self.name, value=val, var=self.var, lb=self.lb, ub=self.ub)
para_copy._expr = self._expr
para_copy.namespace = self.namespace
para_copy.is_global = self.is_global
para_copy.error = self.error
para_copy.function = self.function
return para_copy
def fix(self):
self._value = self.value
self.namespace = {}

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import re import re
from collections import OrderedDict from collections import OrderedDict
from io import StringIO
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -186,7 +187,7 @@ class FitResult(Points):
nice_name = m.group(1) nice_name = m.group(1)
if func_number in split_funcs: if func_number in split_funcs:
nice_func = split_funcs[func_number] nice_func = split_funcs[func_number]
pvalue.fix()
pvalue.name = nice_name pvalue.name = nice_name
pvalue.function = nice_func pvalue.function = nice_func
parameter_dic[pname] = pvalue parameter_dic[pname] = pvalue
@ -223,27 +224,30 @@ class FitResult(Points):
return self.nobs-self.nvar return self.nobs-self.nvar
def pprint(self, statistics=True, correlations=True): def pprint(self, statistics=True, correlations=True):
print('Fit result:') s = StringIO()
print(' model :', self.name) s.write('Fit result:\n')
print(' #data :', self.nobs) s.write(f' model : {self.name}\n')
print(' #var :', self.nvar) s.write(f' #data : {self.nobs}\n')
print('\nParameter') s.write(f' #var : {self.nvar}\n')
print(self.parameter_string()) s.write('\nParameter\n')
s.write(self.parameter_string())
if statistics: if statistics:
print('Statistics') s.write('\nStatistics\n')
for k, v in self.statistics.items(): for k, v in self.statistics.items():
print(f' {k} : {v:.4f}') s.write(f' {k} : {v:.4f}\n')
if correlations and self.correlation is not None: if correlations and self.correlation is not None:
print('\nCorrelation (partial corr.)') s.write('\nCorrelation (partial corr.)\n')
print(self._correlation_string()) s.write(self._correlation_string())
print() s.write('\n')
print(s.getvalue())
def parameter_string(self): def parameter_string(self):
ret_val = '' ret_val = ''
for pval in self.parameter.values(): for pkey, pval in self.parameter.items():
ret_val += convert(str(pval), old='tex', new='str') + '\n' ret_val += convert(str(pval), old='tex', new='str') + '\n'
if self.fun_kwargs: if self.fun_kwargs:
@ -255,9 +259,7 @@ class FitResult(Points):
def _correlation_string(self): def _correlation_string(self):
ret_val = '' ret_val = ''
for p_i, p_j, corr_ij, pcorr_ij in self.correlation_list(): for p_i, p_j, corr_ij, pcorr_ij in self.correlation_list():
ret_val += ' {} / {} : {:.4f} ({:.4f})\n'.format(convert(p_i, old='tex', new='str'), ret_val += f" {convert(p_i, old='tex', new='str')} / {convert(p_j, old='tex', new='str')} : {corr_ij:.4f} ({pcorr_ij:.4f})\n"
convert(p_j, old='tex', new='str'),
corr_ij, pcorr_ij)
return ret_val return ret_val
def correlation_list(self, limit=0.1): def correlation_list(self, limit=0.1):

View File

@ -35,8 +35,8 @@ class Gaussian:
class Lorentzian: class Lorentzian:
type = 'Spectrum' type = 'Spectrum'
name = 'Lorentzian' name = 'Lorentzian'
equation = 'A (2/\pi)w/[4*(x-\mu)^{2} + w^{2}] + A_{0}' equation = r'A (2/\pi)w/[4*(x-\mu)^{2} + w^{2}] + A_{0}'
params = ['A', '\mu', 'w', 'A_{0}'] params = ['A', r'\mu', 'w', 'A_{0}']
ext_params = None ext_params = None
bounds = [(0, None), (None, None), (0, None), (None, None)] bounds = [(0, None), (None, None), (0, None), (None, None)]
@ -62,9 +62,9 @@ class Lorentzian:
class PseudoVoigt: class PseudoVoigt:
type = 'Spectrum' type = 'Spectrum'
name = 'Pseudo Voigt' name = 'Pseudo Voigt'
equation = 'A [R*2/\pi*w/[4*(x-\mu)^{2} + w^{2}] + ' \ equation = r'A [R*2/\pi*w/[4*(x-\mu)^{2} + w^{2}] + ' \
'(1-R)*sqrt(4*ln(2)/pi)/w*exp(-4*ln(2)[(x-\mu)/w]^{2})] + A_{0}' r'(1-R)*sqrt(4*ln(2)/pi)/w*exp(-4*ln(2)[(x-\mu)/w]^{2})] + A_{0}'
params = ['A', 'R', '\mu', 'w', 'A_{0}'] params = ['A', 'R', r'\mu', 'w', 'A_{0}']
ext_params = None ext_params = None
bounds = [(0, None), (0, 1), (None, None), (0, None)] bounds = [(0, None), (0, 1), (None, None), (0, None)]

View File

@ -7,7 +7,7 @@
<x>0</x> <x>0</x>
<y>0</y> <y>0</y>
<width>365</width> <width>365</width>
<height>78</height> <height>66</height>
</rect> </rect>
</property> </property>
<property name="sizePolicy"> <property name="sizePolicy">
@ -62,7 +62,7 @@
<item> <item>
<widget class="LineEdit" name="parameter_line"> <widget class="LineEdit" name="parameter_line">
<property name="sizePolicy"> <property name="sizePolicy">
<sizepolicy hsizetype="Fixed" vsizetype="Fixed"> <sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed">
<horstretch>0</horstretch> <horstretch>0</horstretch>
<verstretch>0</verstretch> <verstretch>0</verstretch>
</sizepolicy> </sizepolicy>
@ -78,19 +78,6 @@
</property> </property>
</widget> </widget>
</item> </item>
<item>
<spacer name="horizontalSpacer">
<property name="orientation">
<enum>Qt::Horizontal</enum>
</property>
<property name="sizeHint" stdset="0">
<size>
<width>40</width>
<height>20</height>
</size>
</property>
</spacer>
</item>
<item> <item>
<widget class="QCheckBox" name="fixed_check"> <widget class="QCheckBox" name="fixed_check">
<property name="text"> <property name="text">
@ -105,19 +92,6 @@
</property> </property>
</widget> </widget>
</item> </item>
<item>
<widget class="QToolButton" name="toolButton">
<property name="text">
<string/>
</property>
<property name="popupMode">
<enum>QToolButton::InstantPopup</enum>
</property>
<property name="arrowType">
<enum>Qt::RightArrow</enum>
</property>
</widget>
</item>
</layout> </layout>
</item> </item>
<item> <item>