forked from IPKM/nmreval
Merge branch 'fit_constraints'
# Conflicts: # src/gui_qt/main/management.py
This commit is contained in:
commit
04037d6b4d
@ -1,10 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Form implementation generated from reading ui file 'resources/_ui/fitmodelwidget.ui'
|
||||
# Form implementation generated from reading ui file 'src/resources/_ui/fitmodelwidget.ui'
|
||||
#
|
||||
# Created by: PyQt5 UI code generator 5.12.3
|
||||
# Created by: PyQt5 UI code generator 5.15.9
|
||||
#
|
||||
# WARNING! All changes made in this file will be lost!
|
||||
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
|
||||
# run again. Do not edit this file unless you know what you are doing.
|
||||
|
||||
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
@ -13,7 +14,7 @@ from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
class Ui_FitParameter(object):
|
||||
def setupUi(self, FitParameter):
|
||||
FitParameter.setObjectName("FitParameter")
|
||||
FitParameter.resize(365, 78)
|
||||
FitParameter.resize(365, 66)
|
||||
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.MinimumExpanding)
|
||||
sizePolicy.setHorizontalStretch(0)
|
||||
sizePolicy.setVerticalStretch(0)
|
||||
@ -36,7 +37,7 @@ class Ui_FitParameter(object):
|
||||
self.parametername.setObjectName("parametername")
|
||||
self.horizontalLayout_2.addWidget(self.parametername)
|
||||
self.parameter_line = LineEdit(FitParameter)
|
||||
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed)
|
||||
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed)
|
||||
sizePolicy.setHorizontalStretch(0)
|
||||
sizePolicy.setVerticalStretch(0)
|
||||
sizePolicy.setHeightForWidth(self.parameter_line.sizePolicy().hasHeightForWidth())
|
||||
@ -44,20 +45,12 @@ class Ui_FitParameter(object):
|
||||
self.parameter_line.setText("")
|
||||
self.parameter_line.setObjectName("parameter_line")
|
||||
self.horizontalLayout_2.addWidget(self.parameter_line)
|
||||
spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
|
||||
self.horizontalLayout_2.addItem(spacerItem)
|
||||
self.fixed_check = QtWidgets.QCheckBox(FitParameter)
|
||||
self.fixed_check.setObjectName("fixed_check")
|
||||
self.horizontalLayout_2.addWidget(self.fixed_check)
|
||||
self.global_checkbox = QtWidgets.QCheckBox(FitParameter)
|
||||
self.global_checkbox.setObjectName("global_checkbox")
|
||||
self.horizontalLayout_2.addWidget(self.global_checkbox)
|
||||
self.toolButton = QtWidgets.QToolButton(FitParameter)
|
||||
self.toolButton.setText("")
|
||||
self.toolButton.setPopupMode(QtWidgets.QToolButton.InstantPopup)
|
||||
self.toolButton.setArrowType(QtCore.Qt.RightArrow)
|
||||
self.toolButton.setObjectName("toolButton")
|
||||
self.horizontalLayout_2.addWidget(self.toolButton)
|
||||
self.verticalLayout.addLayout(self.horizontalLayout_2)
|
||||
self.frame = QtWidgets.QFrame(FitParameter)
|
||||
sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Maximum)
|
||||
|
@ -8,6 +8,7 @@ from pyqtgraph import mkPen
|
||||
|
||||
from nmreval.data.points import Points
|
||||
from nmreval.data.signals import Signal
|
||||
from nmreval.lib.logger import logger
|
||||
from nmreval.utils.text import convert
|
||||
from nmreval.data.bds import BDS
|
||||
from nmreval.data.dsc import DSC
|
||||
@ -356,7 +357,7 @@ class ExperimentContainer(QtCore.QObject):
|
||||
elif mode in ['imag', 'all'] and self.plot_imag is not None:
|
||||
self.plot_imag.set_symbol(symbol=symbol, size=size, color=color)
|
||||
else:
|
||||
print('Updating symbol failed for ' + str(self.id))
|
||||
logger.warning(f'Updating symbol failed for {self.id}')
|
||||
|
||||
def setLine(self, *, width=None, style=None, color=None, mode='real'):
|
||||
if mode in ['real', 'all']:
|
||||
@ -368,7 +369,7 @@ class ExperimentContainer(QtCore.QObject):
|
||||
elif mode in ['imag', 'all'] and self.plot_imag is not None:
|
||||
self.plot_imag.set_line(width=width, style=style, color=color)
|
||||
else:
|
||||
print('Updating line failed for ' + str(self.id))
|
||||
logger.warning(f'Updating line failed for {self.id}')
|
||||
|
||||
def update_property(self, key1: str, key2: str, value: Any):
|
||||
keykey = key2.split()
|
||||
|
@ -1,3 +1,4 @@
|
||||
from nmreval.lib.logger import logger
|
||||
from nmreval.math import apodization
|
||||
from nmreval.lib.importer import find_models
|
||||
from nmreval.utils.text import convert
|
||||
@ -67,7 +68,7 @@ class EditSignalWidget(QtWidgets.QWidget, Ui_Form):
|
||||
self.do_something.emit(sender, (ph0, ph1, pvt))
|
||||
|
||||
else:
|
||||
print('You should never reach this by accident.')
|
||||
logger.warning(f'You should never reach this by accident, invalid sender {sender!r}')
|
||||
|
||||
@QtCore.pyqtSlot(int, name='on_apodcombobox_currentIndexChanged')
|
||||
def change_apodization(self, index):
|
||||
|
@ -19,19 +19,20 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
|
||||
super().__init__(parent)
|
||||
self.setupUi(self)
|
||||
|
||||
self.parametername.setText(label + ' ')
|
||||
self.name = label
|
||||
|
||||
self.parametername.setText(convert(label) + ' ')
|
||||
|
||||
validator = QtGui.QDoubleValidator()
|
||||
self.parameter_line.setValidator(validator)
|
||||
self.parameter_line.setText('1')
|
||||
self.parameter_line.setMaximumWidth(240)
|
||||
self.lineEdit.setMaximumWidth(60)
|
||||
self.lineEdit_2.setMaximumWidth(60)
|
||||
self.parameter_line.setMaximumWidth(160)
|
||||
self.lineEdit.setMaximumWidth(100)
|
||||
self.lineEdit_2.setMaximumWidth(100)
|
||||
|
||||
self.label_3.setText(f'< {label} <')
|
||||
self.label_3.setText(f'< {convert(label)} <')
|
||||
|
||||
self.checkBox.stateChanged.connect(self.enableBounds)
|
||||
self.global_checkbox.stateChanged.connect(lambda: self.state_changed.emit())
|
||||
self.parameter_line.editingFinished.connect(self.update_parameter)
|
||||
self.parameter_line.values_requested.connect(lambda: self.value_requested.emit(self))
|
||||
self.parameter_line.replace_single_values.connect(lambda: self.replace_single_value.emit(None))
|
||||
self.parameter_line.editingFinished.connect(lambda: self.value_changed.emit(self.parameter_line.text()))
|
||||
@ -40,18 +41,12 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
|
||||
if fixed:
|
||||
self.fixed_check.hide()
|
||||
|
||||
self.menu = QtWidgets.QMenu(self)
|
||||
self.add_links()
|
||||
|
||||
self.is_linked = None
|
||||
self.parameter_pos = None
|
||||
self.func_idx = None
|
||||
|
||||
self._linetext = '1'
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return convert(self.parametername.text().strip(), old='html', new='str')
|
||||
self.menu = QtWidgets.QMenu(self)
|
||||
|
||||
def set_parameter_string(self, p: str):
|
||||
self.parameter_line.setText(p)
|
||||
@ -71,11 +66,6 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
|
||||
|
||||
def set_parameter(self, p: float | None, bds: tuple[float, float, bool] = None,
|
||||
fixed: bool = None, glob: bool = None):
|
||||
if p is None:
|
||||
# bad hack: linked parameter return (None, linked parameter)
|
||||
# if p is None -> parameter is linked to argument given by bds
|
||||
self.link_parameter(linkto=bds)
|
||||
else:
|
||||
ptext = f'{p:.4g}'
|
||||
|
||||
self.set_parameter_string(ptext)
|
||||
@ -90,19 +80,10 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
|
||||
self.global_checkbox.setCheckState(QtCore.Qt.Checked if glob else QtCore.Qt.Unchecked)
|
||||
|
||||
def get_parameter(self):
|
||||
if self.is_linked:
|
||||
try:
|
||||
p = float(self._linetext)
|
||||
except ValueError:
|
||||
p = 1.0
|
||||
else:
|
||||
try:
|
||||
p = float(self.parameter_line.text().replace(',', '.'))
|
||||
except ValueError:
|
||||
_ = QtWidgets.QMessageBox().warning(self, 'Invalid value',
|
||||
f'{self.parametername.text()} contains invalid values',
|
||||
QtWidgets.QMessageBox.Cancel)
|
||||
return None
|
||||
p = self.parameter_line.text().replace(',', '.')
|
||||
|
||||
if self.checkBox.isChecked():
|
||||
try:
|
||||
@ -119,75 +100,27 @@ class FitModelWidget(QtWidgets.QWidget, Ui_FitParameter):
|
||||
|
||||
bounds = (lb, rb)
|
||||
|
||||
return p, bounds, not self.fixed_check.isChecked(), self.global_checkbox.isChecked(), self.is_linked
|
||||
return p, bounds, not self.fixed_check.isChecked(), self.global_checkbox.isChecked()
|
||||
|
||||
@QtCore.pyqtSlot(bool)
|
||||
def set_fixed(self, state: bool):
|
||||
# self.global_checkbox.setVisible(not state)
|
||||
self.frame.setVisible(not state)
|
||||
|
||||
def add_links(self, parameter: dict = None):
|
||||
if parameter is None:
|
||||
parameter = {}
|
||||
self.menu.clear()
|
||||
|
||||
ac = QtWidgets.QAction('Link to...', self)
|
||||
ac.triggered.connect(self.link_parameter)
|
||||
self.menu.addAction(ac)
|
||||
|
||||
for model_key, model_funcs in parameter.items():
|
||||
m = QtWidgets.QMenu('Model ' + model_key, self)
|
||||
for func_name, func_params in model_funcs.items():
|
||||
m2 = QtWidgets.QMenu(func_name, m)
|
||||
for p_name, idx in func_params:
|
||||
ac = QtWidgets.QAction(p_name, m2)
|
||||
ac.setData((model_key, *idx))
|
||||
ac.triggered.connect(self.link_parameter)
|
||||
m2.addAction(ac)
|
||||
m.addMenu(m2)
|
||||
self.menu.addMenu(m)
|
||||
|
||||
self.toolButton.setMenu(self.menu)
|
||||
|
||||
@QtCore.pyqtSlot()
|
||||
def link_parameter(self, linkto=None):
|
||||
if linkto is None:
|
||||
action = self.sender()
|
||||
else:
|
||||
action = False
|
||||
for m in self.menu.actions():
|
||||
if m.menu():
|
||||
for a in m.menu().actions():
|
||||
if a.data() == linkto:
|
||||
action = a
|
||||
break
|
||||
if action:
|
||||
break
|
||||
|
||||
if (self.func_idx, self.parameter_pos) == action.data():
|
||||
return
|
||||
def update_parameter(self):
|
||||
new_value = self.parameter_line.text()
|
||||
if not new_value:
|
||||
self.parameter_line.setText('1')
|
||||
|
||||
try:
|
||||
new_text = f'Linked to {action.parentWidget().title()}.{action.text()}'
|
||||
self._linetext = self.parameter_line.text()
|
||||
self.parameter_line.setText(new_text)
|
||||
self.parameter_line.setEnabled(False)
|
||||
self.global_checkbox.hide()
|
||||
self.global_checkbox.blockSignals(True)
|
||||
self.global_checkbox.setCheckState(QtCore.Qt.Checked)
|
||||
self.global_checkbox.blockSignals(False)
|
||||
self.frame.hide()
|
||||
self.is_linked = action.data()
|
||||
float(new_value)
|
||||
is_text = False
|
||||
except ValueError:
|
||||
is_text = True
|
||||
self.global_checkbox.setCheckState(False)
|
||||
|
||||
except AttributeError:
|
||||
self.parameter_line.setText(self._linetext)
|
||||
self.parameter_line.setEnabled(True)
|
||||
if self.fixed_check.isEnabled():
|
||||
self.global_checkbox.show()
|
||||
self.frame.show()
|
||||
self.is_linked = None
|
||||
|
||||
self.state_changed.emit()
|
||||
self.set_fixed(is_text)
|
||||
|
||||
|
||||
class QSaveModelDialog(QtWidgets.QDialog, Ui_SaveDialog):
|
||||
@ -282,8 +215,17 @@ class FitModelTree(QtWidgets.QTreeWidget):
|
||||
idx = item.data(0, self.counterRole)
|
||||
self.itemRemoved.emit(idx)
|
||||
|
||||
def add_function(self, idx: int, cnt: int, op: int, name: str, color: QtGui.QColor | str | tuple,
|
||||
parent: QtWidgets.QTreeWidgetItem = None, children: list = None, active: bool = True, **kwargs):
|
||||
def add_function(self,
|
||||
idx: int,
|
||||
cnt: int,
|
||||
op: int,
|
||||
name: str,
|
||||
color: QtGui.QColor | str | tuple,
|
||||
parent: QtWidgets.QTreeWidgetItem = None,
|
||||
children: list = None,
|
||||
active: bool = True,
|
||||
param_names: list[str] = None,
|
||||
**kwargs):
|
||||
"""
|
||||
Add function to tree and dictionary of functions.
|
||||
"""
|
||||
@ -298,6 +240,10 @@ class FitModelTree(QtWidgets.QTreeWidget):
|
||||
it.setData(0, self.counterRole, cnt)
|
||||
it.setData(0, self.operatorRole, op)
|
||||
it.setText(0, name)
|
||||
if param_names is not None:
|
||||
it.setToolTip(0,
|
||||
'Parameter names:\n' +
|
||||
'\n'.join(f'{pn}({cnt})' for pn in param_names))
|
||||
it.setForeground(0, QtGui.QBrush(color))
|
||||
|
||||
it.setIcon(0, get_icon(self.icons[op]))
|
||||
|
@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from nmreval.fit.parameter import Parameter
|
||||
from nmreval.utils.text import convert
|
||||
|
||||
from ..Qt import QtWidgets, QtCore, QtGui
|
||||
@ -62,8 +65,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
|
||||
self.glob_values = [1] * len(func.params)
|
||||
|
||||
for k, v in enumerate(func.params):
|
||||
name = convert(v)
|
||||
widgt = FitModelWidget(label=name, parent=self.scrollwidget)
|
||||
widgt = FitModelWidget(label=v, parent=self.scrollwidget)
|
||||
widgt.parameter_pos = k
|
||||
widgt.func_idx = idx
|
||||
try:
|
||||
@ -83,7 +85,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
|
||||
self.global_parameter.append(widgt)
|
||||
self.scrollwidget.layout().addWidget(widgt)
|
||||
|
||||
widgt2 = ParameterSingleWidget(name=name, parent=self.scrollwidget2)
|
||||
widgt2 = ParameterSingleWidget(name=v, parent=self.scrollwidget2)
|
||||
widgt2.valueChanged.connect(self.change_single_parameter)
|
||||
widgt2.removeSingleValue.connect(self.change_single_parameter)
|
||||
widgt2.installEventFilter(self)
|
||||
@ -115,20 +117,22 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
|
||||
self.scrollwidget.layout().addStretch(1)
|
||||
self.scrollwidget2.layout().addStretch(1)
|
||||
|
||||
def set_links(self, parameter):
|
||||
for w in self.global_parameter:
|
||||
if isinstance(w, FitModelWidget):
|
||||
w.add_links(parameter)
|
||||
# def set_links(self, parameter):
|
||||
# for w in self.global_parameter:
|
||||
# if isinstance(w, FitModelWidget):
|
||||
# w.add_links(parameter)
|
||||
|
||||
@QtCore.pyqtSlot(str)
|
||||
def change_global_parameter(self, value: str, idx: int = None):
|
||||
if idx is None:
|
||||
idx = self.global_parameter.index(self.sender())
|
||||
|
||||
self.glob_values[idx] = float(value)
|
||||
# self.glob_values[idx] = float(value)
|
||||
self.glob_values[idx] = value
|
||||
if self.data_values[self.comboBox.currentData()][idx] is None:
|
||||
self.data_parameter[idx].blockSignals(True)
|
||||
self.data_parameter[idx].value = float(value)
|
||||
# self.data_parameter[idx].value = float(value)
|
||||
self.data_parameter[idx].value = value
|
||||
self.data_parameter[idx].blockSignals(False)
|
||||
|
||||
@QtCore.pyqtSlot(str, object)
|
||||
@ -171,7 +175,7 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
|
||||
# disable single parameter if it is set global, enable if global is unset
|
||||
widget = self.sender()
|
||||
idx = self.global_parameter.index(widget)
|
||||
enable = (widget.global_checkbox.checkState() == QtCore.Qt.Unchecked) and (widget.is_linked is None)
|
||||
enable = (widget.global_checkbox.checkState() == QtCore.Qt.Unchecked)
|
||||
self.data_parameter[idx].setEnabled(enable)
|
||||
|
||||
def select_next_preview(self, direction):
|
||||
@ -204,64 +208,50 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
|
||||
if sid not in self.data_values:
|
||||
self.data_values[sid] = [None] * len(self.data_parameter)
|
||||
|
||||
def get_parameter(self, use_func=None):
|
||||
def get_parameter(self, use_func=None) -> tuple[dict, list]:
|
||||
bds = []
|
||||
is_global = []
|
||||
is_fixed = []
|
||||
globs = []
|
||||
is_linked = []
|
||||
param_general = []
|
||||
|
||||
for g in self.global_parameter:
|
||||
if isinstance(g, FitModelWidget):
|
||||
p_i, bds_i, fixed_i, global_i, link_i = g.get_parameter()
|
||||
p_i, bds_i, fixed_i, global_i = g.get_parameter()
|
||||
parameter_i = Parameter(name=g.name, value=p_i, lb=bds_i[0], ub=bds_i[1], var=fixed_i)
|
||||
param_general.append(parameter_i)
|
||||
|
||||
globs.append(p_i)
|
||||
bds.append(bds_i)
|
||||
is_fixed.append(fixed_i)
|
||||
is_global.append(global_i)
|
||||
is_linked.append(link_i)
|
||||
|
||||
lb, ub = list(zip(*bds))
|
||||
|
||||
data_parameter = {}
|
||||
if use_func is None:
|
||||
use_func = list(self.data_values.keys())
|
||||
|
||||
global_p = None
|
||||
for sid, parameter in self.data_values.items():
|
||||
if sid not in use_func:
|
||||
continue
|
||||
|
||||
kw_p = {}
|
||||
p = []
|
||||
if global_p is None:
|
||||
global_p = {'p': [], 'idx': [], 'var': [], 'ub': [], 'lb': []}
|
||||
|
||||
for i, (p_i, g) in enumerate(zip(parameter, self.global_parameter)):
|
||||
if isinstance(g, FitModelWidget):
|
||||
if (p_i is None) or is_global[i]:
|
||||
p.append(globs[i])
|
||||
if is_global[i]:
|
||||
if i not in global_p['idx']:
|
||||
global_p['p'].append(globs[i])
|
||||
global_p['idx'].append(i)
|
||||
global_p['var'].append(is_fixed[i])
|
||||
global_p['ub'].append(ub[i])
|
||||
global_p['lb'].append(lb[i])
|
||||
# set has no oen value
|
||||
p.append(param_general[i].copy())
|
||||
else:
|
||||
p.append(p_i)
|
||||
|
||||
lb, ub = bds[i]
|
||||
try:
|
||||
if p[i] > ub[i]:
|
||||
raise ValueError(f'Parameter {g.name} is outside bounds ({lb[i]}, {ub[i]})')
|
||||
if not (lb < p_i < ub):
|
||||
raise ValueError(f'Parameter {g.name} is outside bounds ({lb}, {ub})')
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
if p[i] < lb[i]:
|
||||
raise ValueError(f'Parameter {g.name} is outside bounds ({lb[i]}, {ub[i]})')
|
||||
except TypeError:
|
||||
pass
|
||||
# create Parameter
|
||||
p.append(
|
||||
Parameter(name=g.name, value=p_i, lb=lb, ub=ub, var=is_fixed[i])
|
||||
)
|
||||
|
||||
else:
|
||||
if p_i is None:
|
||||
@ -273,7 +263,15 @@ class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
|
||||
|
||||
data_parameter[sid] = (p, kw_p)
|
||||
|
||||
return data_parameter, lb, ub, is_fixed, global_p, is_linked
|
||||
global_parameter = []
|
||||
for param, global_flag in zip(param_general, is_global):
|
||||
if global_flag:
|
||||
param.is_global = True
|
||||
global_parameter.append(param)
|
||||
else:
|
||||
global_parameter.append(None)
|
||||
|
||||
return data_parameter, global_parameter
|
||||
|
||||
def set_parameter(self, set_id: str | None, parameter: list[float]) -> int:
|
||||
num_parameter = list(filter(lambda g: not isinstance(g, SelectionWidget), self.global_parameter))
|
||||
@ -304,12 +302,12 @@ class ParameterSingleWidget(QtWidgets.QWidget):
|
||||
|
||||
self._init_ui()
|
||||
|
||||
self._name = name
|
||||
self.name = name
|
||||
self.label.setText(convert(name))
|
||||
self.label.setToolTip('If this is bold then this parameter is only for this data. '
|
||||
'Otherwise, the general parameter is used and displayed')
|
||||
|
||||
self.value_line.setValidator(QtGui.QDoubleValidator())
|
||||
# self.value_line.setValidator(QtGui.QDoubleValidator())
|
||||
self.value_line.textChanged.connect(lambda: self.valueChanged.emit(self.value) if self.value is not None else 0)
|
||||
self.reset_button.clicked.connect(lambda x: self.removeSingleValue.emit())
|
||||
|
||||
@ -343,9 +341,10 @@ class ParameterSingleWidget(QtWidgets.QWidget):
|
||||
|
||||
@value.setter
|
||||
def value(self, val):
|
||||
self.value_line.setText(f'{float(val):.5g}')
|
||||
# self.value_line.setText(f'{float(val):.5g}')
|
||||
self.value_line.setText(f'{val}')
|
||||
|
||||
def show_as_local_parameter(self, is_local):
|
||||
def show_as_local_parameter(self, is_local: bool):
|
||||
if is_local:
|
||||
self.label.setStyleSheet('font-weight: bold;')
|
||||
else:
|
||||
|
@ -128,7 +128,7 @@ class QFunctionWidget(QtWidgets.QWidget, Ui_Form):
|
||||
|
||||
self.newFunction.emit(idx, cnt)
|
||||
|
||||
self.add_function(idx, cnt, op, name, col)
|
||||
self.add_function(idx, cnt, op, name, col, param_names=self.functions[idx].params)
|
||||
|
||||
def add_function(self, idx: int, cnt: int, op: int,
|
||||
name: str, color: str | tuple[float, float, float] | BaseColor, **kwargs):
|
||||
@ -141,6 +141,7 @@ class QFunctionWidget(QtWidgets.QWidget, Ui_Form):
|
||||
qcolor = QtGui.QColor.fromRgbF(*color)
|
||||
else:
|
||||
qcolor = QtGui.QColor(color)
|
||||
|
||||
self.functree.add_function(idx, cnt, op, name, qcolor, **kwargs)
|
||||
|
||||
f = self.functions[idx]
|
||||
|
@ -9,6 +9,9 @@ import numpy as np
|
||||
from pyqtgraph import mkPen
|
||||
|
||||
from nmreval.fit._meta import MultiModel, ModelFactory
|
||||
from nmreval.fit.data import Data
|
||||
from nmreval.fit.model import Model
|
||||
from nmreval.fit.parameter import Parameters
|
||||
from nmreval.fit.result import FitResult
|
||||
|
||||
from .fit_forms import FitTableWidget
|
||||
@ -116,7 +119,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
|
||||
|
||||
# collect parameter names etc. to allow linkage
|
||||
self._func_list[self._current_model] = self.functionwidget.get_parameter_list()
|
||||
dialog.set_links(self._func_list)
|
||||
# dialog.set_links(self._func_list)
|
||||
|
||||
# show same tab (general parameter/Data parameter)
|
||||
tab_idx = 0
|
||||
@ -219,57 +222,49 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
|
||||
|
||||
def _prepare(self, model: list, function_use: list = None,
|
||||
parameter: dict = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]:
|
||||
|
||||
if parameter is None:
|
||||
parameter = {'parameter': {}, 'lb': (), 'ub': (), 'var': [],
|
||||
'glob': {'idx': [], 'p': [], 'var': [], 'lb': [], 'ub': []},
|
||||
'links': [], 'color': []}
|
||||
parameter = {
|
||||
'data_parameter': {},
|
||||
'global_parameter': [],
|
||||
'links': [],
|
||||
'color': [],
|
||||
}
|
||||
|
||||
for i, f in enumerate(model):
|
||||
if not f['active']:
|
||||
continue
|
||||
|
||||
try:
|
||||
p, lb, ub, var, glob, links = self.param_widgets[f['cnt']].get_parameter(function_use)
|
||||
p, glob = self.param_widgets[f['cnt']].get_parameter(function_use)
|
||||
except ValueError as e:
|
||||
_ = QtWidgets.QMessageBox().warning(self, 'Invalid value', str(e),
|
||||
QtWidgets.QMessageBox.Ok)
|
||||
return {}, -1
|
||||
|
||||
p_len = len(parameter['lb'])
|
||||
|
||||
parameter['lb'] += lb
|
||||
parameter['ub'] += ub
|
||||
parameter['var'] += var
|
||||
parameter['links'] += links
|
||||
parameter['color'] += [f['color']]
|
||||
parameter['color'].append(f['color'])
|
||||
parameter['global_parameter'].extend(glob)
|
||||
|
||||
cnt = f['cnt']
|
||||
|
||||
for p_k, v_k in p.items():
|
||||
if add_idx:
|
||||
kw_k = {f'{k}_{cnt}': v for k, v in v_k[1].items()}
|
||||
else:
|
||||
kw_k = v_k[1]
|
||||
|
||||
if p_k in parameter['parameter']:
|
||||
params, kw = parameter['parameter'][p_k]
|
||||
if p_k in parameter['data_parameter']:
|
||||
params, kw = parameter['data_parameter'][p_k]
|
||||
params += v_k[0]
|
||||
kw.update(kw_k)
|
||||
else:
|
||||
parameter['parameter'][p_k] = (v_k[0], kw_k)
|
||||
|
||||
for g_k, g_v in glob.items():
|
||||
if g_k != 'idx':
|
||||
parameter['glob'][g_k] += g_v
|
||||
else:
|
||||
parameter['glob']['idx'] += [idx_i + p_len for idx_i in g_v]
|
||||
parameter['data_parameter'][p_k] = (v_k[0], kw_k)
|
||||
|
||||
if add_idx:
|
||||
cnt += 1
|
||||
|
||||
if f['children']:
|
||||
# recurse for children
|
||||
child_parameter, cnt = self._prepare(f['children'], parameter=parameter, add_idx=add_idx, cnt=cnt)
|
||||
_, cnt = self._prepare(f['children'], parameter=parameter, add_idx=add_idx, cnt=cnt)
|
||||
|
||||
return parameter, cnt
|
||||
|
||||
@ -280,30 +275,43 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
|
||||
data = self.data_table.collect_data(default=self.default_combobox.currentData())
|
||||
|
||||
func_dict = {}
|
||||
for k, mod in self.models.items():
|
||||
func, order, param_len = ModelFactory.create_from_list(mod)
|
||||
for model_name, model_parameter in self.models.items():
|
||||
func, order, param_len = ModelFactory.create_from_list(model_parameter)
|
||||
|
||||
if func is None:
|
||||
continue
|
||||
|
||||
if k in data:
|
||||
parameter, _ = self._prepare(mod, function_use=data[k], add_idx=isinstance(func, MultiModel))
|
||||
func = Model(func)
|
||||
|
||||
if model_name in data:
|
||||
parameter, _ = self._prepare(model_parameter, function_use=data[model_name], add_idx=isinstance(func, MultiModel))
|
||||
|
||||
if parameter is None:
|
||||
return
|
||||
|
||||
for (data_parameter, _) in parameter['data_parameter'].values():
|
||||
for pname, param in zip(func.params, data_parameter):
|
||||
param.name = pname
|
||||
|
||||
if self._complex[model_name] is not None:
|
||||
for p_k, p_v in parameter['data_parameter'].items():
|
||||
p_v[1].update({'complex_mode': self._complex[model_name]})
|
||||
parameter['data_parameter'][p_k] = p_v[0], p_v[1]
|
||||
|
||||
for pname, param_value in zip(func.params, parameter['global_parameter']):
|
||||
if param_value is not None:
|
||||
param_value.name = pname
|
||||
func.set_global_parameter(param_value)
|
||||
|
||||
parameter['func'] = func
|
||||
parameter['order'] = order
|
||||
parameter['len'] = param_len
|
||||
parameter['complex'] = self._complex[k]
|
||||
if self._complex[k] is not None:
|
||||
for p_k, p_v in parameter['parameter'].items():
|
||||
p_v[1].update({'complex_mode': self._complex[k]})
|
||||
parameter['parameter'][p_k] = p_v[0], p_v[1]
|
||||
parameter['complex'] = self._complex[model_name]
|
||||
|
||||
func_dict[k] = parameter
|
||||
func_dict[model_name] = parameter
|
||||
|
||||
replaceable = []
|
||||
for k, v in func_dict.items():
|
||||
for model_name, v in func_dict.items():
|
||||
for i, link_i in enumerate(v['links']):
|
||||
if link_i is None:
|
||||
continue
|
||||
@ -334,7 +342,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
|
||||
QtWidgets.QMessageBox.Ok)
|
||||
return
|
||||
|
||||
replaceable.append((k, i, rep_model, repl_idx))
|
||||
replaceable.append((model_name, i, rep_model, repl_idx))
|
||||
|
||||
replace_value = None
|
||||
for p_k in f['parameter'].values():
|
||||
@ -412,31 +420,37 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
|
||||
def make_previews(self, x, models_parameters: dict):
|
||||
self.preview_lines = []
|
||||
|
||||
# needed to create namespace
|
||||
param_dict = Parameters()
|
||||
|
||||
cnt = 0
|
||||
for model in models_parameters.values():
|
||||
f = model['func']
|
||||
for parameter_list in model['data_parameter'].values():
|
||||
for i, p_value in enumerate(parameter_list[0]):
|
||||
p_value.name = f.params[i]
|
||||
param_dict.add_parameter(f'a{cnt}', p_value)
|
||||
cnt += 1
|
||||
|
||||
for k, model in models_parameters.items():
|
||||
f = model['func']
|
||||
is_complex = self._complex[k]
|
||||
|
||||
parameters = model['parameter']
|
||||
parameters = model['data_parameter']
|
||||
color = model['color']
|
||||
|
||||
seen_parameter = []
|
||||
|
||||
for p, kwargs in parameters.values():
|
||||
if (p, kwargs) in seen_parameter:
|
||||
# plot only previews with different parameter
|
||||
continue
|
||||
|
||||
seen_parameter.append((p, kwargs))
|
||||
p_value = [pp.value for pp in p]
|
||||
|
||||
if is_complex is not None:
|
||||
y = f.func(x, *p, complex_mode=is_complex, **kwargs)
|
||||
y = f.func(x, *p_value, complex_mode=is_complex, **kwargs)
|
||||
if np.iscomplexobj(y):
|
||||
self.preview_lines.append(PlotItem(x=x, y=y.real, pen=mkPen(width=3)))
|
||||
self.preview_lines.append(PlotItem(x=x, y=y.imag, pen=mkPen(width=3)))
|
||||
else:
|
||||
self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3)))
|
||||
else:
|
||||
y = f.func(x, *p, **kwargs)
|
||||
y = f.func(x, *p_value, **kwargs)
|
||||
self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3)))
|
||||
|
||||
if isinstance(f, MultiModel):
|
||||
@ -444,7 +458,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
|
||||
if is_complex is not None:
|
||||
sub_kwargs.update({'complex_mode': is_complex})
|
||||
|
||||
for i, s in enumerate(f.subs(x, *p, **sub_kwargs)):
|
||||
for i, s in enumerate(f.subs(x, *p_value, **sub_kwargs)):
|
||||
pen_i = mkPen(QtGui.QColor.fromRgbF(*color[i]))
|
||||
if np.iscomplexobj(s):
|
||||
self.preview_lines.append(PlotItem(x=x, y=s.real, pen=pen_i))
|
||||
@ -452,15 +466,17 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
|
||||
else:
|
||||
self.preview_lines.append(PlotItem(x=x, y=s, pen=pen_i))
|
||||
|
||||
param_dict.clear()
|
||||
|
||||
return self.preview_lines
|
||||
|
||||
def set_parameter(self, parameter: dict[str, FitResult]):
|
||||
# which data uses which model
|
||||
data = self.data_table.collect_data(default=self.default_combobox.currentData())
|
||||
|
||||
for fitted_model, fitted_data in data.items():
|
||||
glob_fit_parameter = []
|
||||
|
||||
for fitted_model, fitted_data in data.items():
|
||||
for fit_id, fit_curve in parameter.items():
|
||||
if fit_id in fitted_data:
|
||||
fit_parameter = list(fit_curve.parameter.values())
|
||||
|
@ -138,9 +138,7 @@ class DrawingsWidget(QtWidgets.QWidget, Ui_Form):
|
||||
graph_id = self.graph_comboBox.currentData()
|
||||
current_lines = self.lines[graph_id]
|
||||
|
||||
print(remove_rows)
|
||||
for i in reversed(remove_rows):
|
||||
print(i)
|
||||
self.tableWidget.removeRow(i)
|
||||
self.line_deleted.emit(current_lines[i], graph_id)
|
||||
|
||||
|
@ -27,7 +27,6 @@ class MdiAreaTile(QtWidgets.QMdiArea):
|
||||
pos = QtCore.QPoint(0, 0)
|
||||
|
||||
for win in window_list:
|
||||
print(win.minimumSize())
|
||||
win.setGeometry(rect)
|
||||
win.move(pos)
|
||||
|
||||
|
@ -1,110 +0,0 @@
|
||||
import os.path
|
||||
import json
|
||||
import urllib.request
|
||||
import webbrowser
|
||||
import random
|
||||
|
||||
from ..Qt import QtGui, QtCore, QtWidgets
|
||||
from .._py.pokemon import Ui_Dialog
|
||||
|
||||
random.seed()
|
||||
|
||||
|
||||
class QPokemon(QtWidgets.QDialog, Ui_Dialog):
|
||||
def __init__(self, number=None, parent=None):
|
||||
super().__init__(parent=parent)
|
||||
self.setupUi(self)
|
||||
self._js = json.load(open(os.path.join(path_to_module, 'utils', 'pokemon.json'), 'r'), encoding='UTF-8')
|
||||
self._id = 0
|
||||
|
||||
if number is not None and number in range(1, len(self._js)+1):
|
||||
poke_nr = f'{number:03d}'
|
||||
self._id = number
|
||||
else:
|
||||
poke_nr = f'{random.randint(1, len(self._js)):03d}'
|
||||
self._id = int(poke_nr)
|
||||
|
||||
self._pokemon = None
|
||||
self.show_pokemon(poke_nr)
|
||||
self.label_15.linkActivated.connect(lambda x: webbrowser.open(x))
|
||||
|
||||
self.buttonBox.clicked.connect(self.randomize)
|
||||
self.next_button.clicked.connect(self.show_next)
|
||||
self.prev_button.clicked.connect(self.show_prev)
|
||||
|
||||
def show_pokemon(self, nr):
|
||||
self._pokemon = self._js[nr]
|
||||
self.setWindowTitle('Pokémon: ' + self._pokemon['Deutsch'])
|
||||
|
||||
for i in range(self.tabWidget.count(), -1, -1):
|
||||
print('i', self.tabWidget.count(), i)
|
||||
try:
|
||||
self.tabWidget.widget(i).deleteLater()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
for n, img in self._pokemon['Bilder']:
|
||||
w = QtWidgets.QWidget()
|
||||
vl = QtWidgets.QVBoxLayout()
|
||||
l = QtWidgets.QLabel(self)
|
||||
l.setAlignment(QtCore.Qt.AlignHCenter)
|
||||
pixmap = QtGui.QPixmap()
|
||||
|
||||
try:
|
||||
pixmap.loadFromData(urllib.request.urlopen(img, timeout=0.5).read())
|
||||
except IOError:
|
||||
l.setText(n)
|
||||
else:
|
||||
sc_pixmap = pixmap.scaled(256, 256, QtCore.Qt.KeepAspectRatio)
|
||||
l.setPixmap(sc_pixmap)
|
||||
|
||||
vl.addWidget(l)
|
||||
w.setLayout(vl)
|
||||
self.tabWidget.addTab(w, n)
|
||||
|
||||
if len(self._pokemon['Bilder']) <= 1:
|
||||
self.tabWidget.tabBar().setVisible(False)
|
||||
else:
|
||||
self.tabWidget.tabBar().setVisible(True)
|
||||
self.tabWidget.adjustSize()
|
||||
|
||||
self.name.clear()
|
||||
keys = ['National-Dex', 'Kategorie', 'Typ', 'Größe', 'Gewicht', 'Farbe', 'Link']
|
||||
label_list = [self.pokedex_nr, self.category, self.poketype, self.weight, self.height, self.color, self.info]
|
||||
for (k, label) in zip(keys, label_list):
|
||||
v = self._pokemon[k]
|
||||
if isinstance(v, list):
|
||||
v = os.path.join('', *v)
|
||||
|
||||
if k == 'Link':
|
||||
v = '<a href={}>{}</a>'.format(v, v)
|
||||
|
||||
label.setText(v)
|
||||
|
||||
for k in ['Deutsch', 'Japanisch', 'Englisch', 'Französisch']:
|
||||
v = self._pokemon[k]
|
||||
self.name.addItem(k + ': ' + v)
|
||||
|
||||
self.adjustSize()
|
||||
|
||||
def randomize(self, idd):
|
||||
if idd.text() == 'Retry':
|
||||
new_number = f'{random.randint(1, len(self._js)):03d}'
|
||||
self._id = int(new_number)
|
||||
self.show_pokemon(new_number)
|
||||
else:
|
||||
self.close()
|
||||
|
||||
def show_next(self):
|
||||
new_number = self._id + 1
|
||||
if new_number > len(self._js):
|
||||
new_number = 1
|
||||
self._id = new_number
|
||||
self.show_pokemon(f'{new_number:03d}')
|
||||
|
||||
def show_prev(self):
|
||||
new_number = self._id - 1
|
||||
if new_number == 0:
|
||||
new_number = len(self._js)
|
||||
self._id = new_number
|
||||
self.show_pokemon(f'{new_number:03d}')
|
@ -441,7 +441,7 @@ class UpperManagement(QtCore.QObject):
|
||||
# all-encompassing error catch
|
||||
try:
|
||||
for model_id, model_p in parameter.items():
|
||||
m = Model(model_p['func'])
|
||||
m = model_p['func']
|
||||
models[model_id] = m
|
||||
|
||||
m_complex = model_p['complex']
|
||||
@ -450,13 +450,16 @@ class UpperManagement(QtCore.QObject):
|
||||
# iterate over order of set id in active order and access parameter inside loop
|
||||
# instead of directly looping
|
||||
try:
|
||||
list_ids = list(model_p['parameter'].keys())
|
||||
list_ids = list(model_p['data_parameter'].keys())
|
||||
set_order = [self.active_id.index(i) for i in list_ids]
|
||||
except ValueError as e:
|
||||
raise Exception('Getting order failed') from e
|
||||
|
||||
for pos in set_order:
|
||||
set_id = list_ids[pos]
|
||||
|
||||
data_i = self.data[set_id]
|
||||
set_params = model_p['data_parameter'][set_id]
|
||||
try:
|
||||
data_i = self.data[set_id]
|
||||
except KeyError as e:
|
||||
@ -499,18 +502,12 @@ class UpperManagement(QtCore.QObject):
|
||||
|
||||
d.set_model(m)
|
||||
try:
|
||||
d.set_parameter(set_params[0], var=model_p['var'],
|
||||
lb=model_p['lb'], ub=model_p['ub'],
|
||||
fun_kwargs=set_params[1])
|
||||
d.set_parameter(set_params[0], fun_kwargs=set_params[1])
|
||||
except Exception as e:
|
||||
raise Exception('Setting parameter failed') from e
|
||||
|
||||
self.fitter.add_data(d)
|
||||
|
||||
model_globs = model_p['glob']
|
||||
if model_globs:
|
||||
m.set_global_parameter(**model_p['glob'])
|
||||
|
||||
for links_i in links:
|
||||
self.fitter.set_link_parameter((models[links_i[0]], links_i[1]),
|
||||
(models[links_i[2]], links_i[3]))
|
||||
@ -1170,7 +1167,6 @@ class UpperManagement(QtCore.QObject):
|
||||
|
||||
@QtCore.pyqtSlot(dict)
|
||||
def calc_relaxation(self, opts: dict):
|
||||
|
||||
params = opts['pts']
|
||||
if len(params) == 4:
|
||||
if params[3]:
|
||||
|
@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .model import Model
|
||||
from .parameter import Parameters
|
||||
from .parameter import Parameters, Parameter
|
||||
|
||||
|
||||
class Data(object):
|
||||
@ -16,7 +18,7 @@ class Data(object):
|
||||
self.model = None
|
||||
self.minimizer = None
|
||||
self.parameter = Parameters()
|
||||
self.para_keys = None
|
||||
self.para_keys: list = []
|
||||
self.fun_kwargs = {}
|
||||
|
||||
def __len__(self):
|
||||
@ -68,12 +70,19 @@ class Data(object):
|
||||
def get_model(self):
|
||||
return self.model
|
||||
|
||||
def set_parameter(self, parameter, var=None, ub=None, lb=None,
|
||||
default_bounds=False, fun_kwargs=None):
|
||||
def set_parameter(self,
|
||||
values: list[float | Parameter],
|
||||
*,
|
||||
var: list[bool] = None,
|
||||
ub: list[float] = None,
|
||||
lb: list[float] = None,
|
||||
default_bounds: bool = False,
|
||||
fun_kwargs: dict = None
|
||||
):
|
||||
"""
|
||||
Creates parameter for this data.
|
||||
If no Model is available, it falls back to the model
|
||||
:param parameter: list of parameters
|
||||
:param values: list of parameters
|
||||
:param var: list of boolean or boolean; False fixes parameter at given list index.
|
||||
Single value is broadcast to all parameter
|
||||
:param ub: list of upper boundaries or float; Single value is broadcast to all parameter.
|
||||
@ -87,23 +96,46 @@ class Data(object):
|
||||
model = self.model
|
||||
if model is None:
|
||||
# Data has no unique
|
||||
if self.minimizer is None:
|
||||
model = None
|
||||
else:
|
||||
if self.minimizer is not None:
|
||||
model = self.minimizer.fit_model
|
||||
self.fun_kwargs.update(model.fun_kwargs)
|
||||
|
||||
if model is None:
|
||||
raise ValueError('No model found, please set model before parameters')
|
||||
|
||||
if default_bounds:
|
||||
if len(values) != len(model.params):
|
||||
raise ValueError('Number of given parameter does not match number of model parameters')
|
||||
|
||||
is_parameter = [isinstance(v, Parameter) for v in values]
|
||||
if all(is_parameter):
|
||||
for p_i in values:
|
||||
key = f"p{next(Parameters.parameter_counter)}"
|
||||
self.parameter.add_parameter(key, p_i)
|
||||
elif any(is_parameter):
|
||||
raise ValueError('list of parameter are not all float of Parameter')
|
||||
|
||||
else:
|
||||
if var is None:
|
||||
var = [True] * len(values)
|
||||
|
||||
if lb is None:
|
||||
if default_bounds:
|
||||
lb = model.lb
|
||||
else:
|
||||
lb = [None] * len(values)
|
||||
|
||||
if ub is None:
|
||||
if default_bounds:
|
||||
ub = model.ub
|
||||
else:
|
||||
ub = [None] * len(values)
|
||||
|
||||
self.para_keys = self.parameter.add_parameter(parameter, var=var, lb=lb, ub=ub)
|
||||
arg_names = ['name', 'value', 'var', 'lb', 'ub']
|
||||
for parameter_arg in zip(model.params, values, var, lb, ub):
|
||||
self.parameter.add(**{arg_name: arg_value for arg_name, arg_value in zip(arg_names, parameter_arg)})
|
||||
|
||||
self.para_keys = list(self.parameter.keys())
|
||||
|
||||
self.fun_kwargs.update(model.fun_kwargs)
|
||||
if fun_kwargs is not None:
|
||||
self.fun_kwargs.update(fun_kwargs)
|
||||
|
||||
@ -123,6 +155,18 @@ class Data(object):
|
||||
else:
|
||||
return [p.value for p in self.minimizer.parameters[self.parameter]]
|
||||
|
||||
def replace_parameter(self, key: str, parameter: Parameter) -> None:
|
||||
tobereplaced = None
|
||||
for k, v in self.parameter.items():
|
||||
if v.name == parameter.name:
|
||||
tobereplaced = k
|
||||
break
|
||||
|
||||
if tobereplaced is None:
|
||||
raise KeyError(f'Global parameter {key} not found in list of parameters')
|
||||
self.para_keys[self.para_keys.index(tobereplaced)] = key
|
||||
self.parameter.replace_parameter(tobereplaced, key, parameter)
|
||||
|
||||
def cost(self, p):
|
||||
"""
|
||||
Cost function :math:`y-f(p, x)`
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from itertools import product
|
||||
|
||||
@ -21,13 +23,70 @@ class FitAbortException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# COST FUNCTIONS: f(x) - y (least_square, minimize), and f(x) (ODR)
|
||||
def _cost_scipy_glob(p: list[float], data: list[Data], varpars: list[str], used_pars: list[list[str]]):
|
||||
# replace values
|
||||
for keys, values in zip(varpars, p):
|
||||
for data_i in data:
|
||||
if keys in data_i.parameter.keys():
|
||||
# TODO move this to scaled_value setter
|
||||
data_i.parameter[keys].scaled_value = values
|
||||
data_i.parameter[keys].namespace[keys] = data_i.parameter[keys].value
|
||||
r = []
|
||||
# unpack parameter and calculate y values and concatenate all
|
||||
for values, p_idx in zip(data, used_pars):
|
||||
actual_parameters = [values.parameter[keys].value for keys in p_idx]
|
||||
r = np.r_[r, values.cost(actual_parameters)]
|
||||
|
||||
return r
|
||||
|
||||
|
||||
def _cost_scipy(p, data, varpars, used_pars):
|
||||
for keys, values in zip(varpars, p):
|
||||
data.parameter[keys].scaled_value = values
|
||||
data.parameter[keys].namespace[keys] = data.parameter[keys].value
|
||||
|
||||
actual_parameters = [data.parameter[keys].value for keys in used_pars]
|
||||
return data.cost(actual_parameters)
|
||||
|
||||
|
||||
def _cost_odr(p: list[float], data: Data, varpars: list[str], used_pars: list[str], fitmode: int=0):
|
||||
for keys, values in zip(varpars, p):
|
||||
data.parameter[keys].scaled_value = values
|
||||
data.parameter[keys].namespace[keys] = data.parameter[keys].value
|
||||
|
||||
actual_parameters = [data.parameter[keys].value for keys in used_pars]
|
||||
|
||||
return data.func(actual_parameters, data.x)
|
||||
|
||||
|
||||
def _cost_odr_glob(p: list[float], data: list[Data], var_pars: list[str], used_pars: list[str]):
|
||||
# replace values
|
||||
for data_i in data:
|
||||
_update_parameter(data_i, var_pars, p)
|
||||
|
||||
r = []
|
||||
# unpack parameter and calculate y values and concatenate all
|
||||
for values, p_idx in zip(data, used_pars):
|
||||
actual_parameters = [values.parameter[keys].value for keys in p_idx]
|
||||
r = np.r_[r, values.func(actual_parameters, values.x)]
|
||||
|
||||
return r
|
||||
|
||||
|
||||
def _update_parameter(data: Data, varied_keys: list[str], parameter: list[float]):
|
||||
for keys, values in zip(varied_keys, parameter):
|
||||
if keys in data.parameter.keys():
|
||||
data.parameter[keys].scaled_value = values
|
||||
data.parameter[keys].namespace[keys] = data.parameter[keys].value
|
||||
|
||||
|
||||
class FitRoutine(object):
|
||||
def __init__(self, mode='lsq'):
|
||||
self.fitmethod = mode
|
||||
self.data = []
|
||||
self.fit_model = None
|
||||
self._no_own_model = []
|
||||
self.parameter = Parameters()
|
||||
self.result = []
|
||||
self.linked = []
|
||||
self._abort = False
|
||||
@ -81,29 +140,27 @@ class FitRoutine(object):
|
||||
|
||||
return self.fit_model
|
||||
|
||||
def set_link_parameter(self, parameter: tuple, replacement: tuple):
|
||||
def set_link_parameter(self, dismissed_param: tuple[Model | Data, str], replacement: tuple[Model, str]):
|
||||
if isinstance(replacement[0], Model):
|
||||
if replacement[1] not in replacement[0].global_parameter:
|
||||
raise KeyError(f'Parameter at pos {replacement[1]} of '
|
||||
f'model {str(replacement[0])} is not global')
|
||||
if replacement[1] not in replacement[0].parameter:
|
||||
raise KeyError(f'Parameter {replacement[1]} of '
|
||||
f'model {replacement[0]} is not global')
|
||||
|
||||
if isinstance(parameter[0], Model):
|
||||
warnings.warn(f'Replaced parameter at pos {parameter[1]} in {str(parameter[0])} '
|
||||
if isinstance(dismissed_param[0], Model):
|
||||
warnings.warn(f'Replaced parameter {dismissed_param[1]} in {dismissed_param[0]} '
|
||||
f'becomes global with linkage.')
|
||||
|
||||
self.linked.append((*parameter, *replacement))
|
||||
self.linked.append((*dismissed_param, *replacement))
|
||||
|
||||
def prepare_links(self):
|
||||
self._no_own_model = []
|
||||
self.parameter = Parameters()
|
||||
_found_models = {}
|
||||
linked_sender = {}
|
||||
|
||||
for v in self.data:
|
||||
linked_sender[v] = set()
|
||||
self.parameter.update(v.parameter.copy())
|
||||
|
||||
# set temporaray model
|
||||
# set temporary model
|
||||
if v.model is None:
|
||||
v.model = self.fit_model
|
||||
self._no_own_model.append(v)
|
||||
@ -111,8 +168,6 @@ class FitRoutine(object):
|
||||
# register model
|
||||
if v.model not in _found_models:
|
||||
_found_models[v.model] = []
|
||||
m_param = v.model.parameter.copy()
|
||||
self.parameter.update(m_param)
|
||||
|
||||
_found_models[v.model].append(v)
|
||||
|
||||
@ -120,24 +175,21 @@ class FitRoutine(object):
|
||||
linked_sender[v.model] = set()
|
||||
|
||||
linked_parameter = {}
|
||||
for par, par_parm, repl, repl_par in self.linked:
|
||||
if isinstance(par, Data):
|
||||
if isinstance(repl, Data):
|
||||
linked_parameter[par.para_keys[par_parm]] = repl.para_keys[repl_par]
|
||||
else:
|
||||
linked_parameter[par.para_keys[par_parm]] = repl.global_parameter[repl_par]
|
||||
for dismiss_model, dismiss_param, replace_model, replace_param in self.linked:
|
||||
linked_sender[replace_model].add(dismiss_model)
|
||||
linked_sender[replace_model].add(replace_model)
|
||||
|
||||
else:
|
||||
if isinstance(repl, Data):
|
||||
par.global_parameter[par_parm] = repl.para_keys[repl_par]
|
||||
else:
|
||||
par.global_parameter[par_parm] = repl.global_parameter[repl_par]
|
||||
replace_key = replace_model.parameter.get_key(replace_param)
|
||||
dismiss_key = dismiss_model.parameter.get_key(dismiss_param)
|
||||
|
||||
linked_sender[repl].add(par)
|
||||
linked_sender[par].add(repl)
|
||||
if isinstance(replace_model, Data):
|
||||
linked_parameter[dismiss_key] = replace_key
|
||||
else:
|
||||
p = dismiss_model.set_global_parameter(dismiss_param, replace_key)
|
||||
p._expr_disp = replace_param
|
||||
|
||||
for mm, m_data in _found_models.items():
|
||||
if mm.global_parameter:
|
||||
if mm.parameter:
|
||||
for dd in m_data:
|
||||
linked_sender[mm].add(dd)
|
||||
linked_sender[dd].add(mm)
|
||||
@ -169,15 +221,13 @@ class FitRoutine(object):
|
||||
logger.info('Fit aborted by user')
|
||||
self._abort = True
|
||||
|
||||
def run(self, mode: str=None):
|
||||
def run(self, mode: str = None):
|
||||
self._abort = False
|
||||
self.parameter = Parameters()
|
||||
|
||||
if mode is None:
|
||||
mode = self.fitmethod
|
||||
|
||||
fit_groups, linked_parameter = self.prepare_links()
|
||||
|
||||
for data_groups in fit_groups:
|
||||
if len(data_groups) == 1 and not self.linked:
|
||||
data = data_groups[0]
|
||||
@ -208,8 +258,21 @@ class FitRoutine(object):
|
||||
|
||||
self.unprep_run()
|
||||
|
||||
for r in self.result:
|
||||
r.pprint()
|
||||
|
||||
return self.result
|
||||
|
||||
def make_preview(self, x: np.ndarray) -> list[np.ndarray]:
|
||||
y_pred = []
|
||||
fit_groups, linked_parameter = self.prepare_links()
|
||||
for data_groups in fit_groups:
|
||||
data = data_groups[0]
|
||||
actual_parameters = [p.value for p in data.parameter.values()]
|
||||
y_pred.append(data.func(actual_parameters, x))
|
||||
|
||||
return y_pred
|
||||
|
||||
def _prep_data(self, data):
|
||||
if data.get_model() is None:
|
||||
data._model = self.fit_model
|
||||
@ -237,22 +300,16 @@ class FitRoutine(object):
|
||||
var = []
|
||||
data_pars = []
|
||||
|
||||
# loopyloop over data that belong to one fit (linked or global)
|
||||
# loopy-loop over data that belong to one fit (linked or global)
|
||||
for data in data_group:
|
||||
actual_pars = []
|
||||
for i, (p_k, v_k) in enumerate(data.parameter.items()):
|
||||
p_k_used = p_k
|
||||
v_k_used = v_k
|
||||
|
||||
# is parameter replaced by global parameter?
|
||||
if i in data.model.global_parameter:
|
||||
p_k_used = data.model.global_parameter[i]
|
||||
v_k_used = self.parameter[p_k_used]
|
||||
for k, v in data.model.parameter.items():
|
||||
data.replace_parameter(k, v)
|
||||
|
||||
# links trump global parameter
|
||||
if p_k_used in linked:
|
||||
p_k_used = linked[p_k_used]
|
||||
v_k_used = self.parameter[p_k_used]
|
||||
actual_pars = []
|
||||
for i, p_k in enumerate(data.para_keys):
|
||||
p_k_used = p_k
|
||||
v_k_used = data.parameter[p_k]
|
||||
|
||||
actual_pars.append(p_k_used)
|
||||
# parameter is variable and was not found before as shared parameter
|
||||
@ -271,48 +328,7 @@ class FitRoutine(object):
|
||||
d._model = None
|
||||
|
||||
self._no_own_model = []
|
||||
|
||||
# COST FUNCTIONS: f(x) - y (least_square, minimize), and f(x) (ODR)
|
||||
def __cost_scipy(self, p, data, varpars, used_pars):
|
||||
for keys, values in zip(varpars, p):
|
||||
self.parameter[keys].scaled_value = values
|
||||
|
||||
actual_parameters = [self.parameter[keys].value for keys in used_pars]
|
||||
return data.cost(actual_parameters)
|
||||
|
||||
def __cost_odr(self, p, data, varpars, used_pars):
|
||||
for keys, values in zip(varpars, p):
|
||||
self.parameter[keys].scaled_value = values
|
||||
|
||||
actual_parameters = [self.parameter[keys].value for keys in used_pars]
|
||||
|
||||
return data.func(actual_parameters, data.x)
|
||||
|
||||
def __cost_scipy_glob(self, p, data, varpars, used_pars):
|
||||
# replace values
|
||||
for keys, values in zip(varpars, p):
|
||||
self.parameter[keys].scaled_value = values
|
||||
|
||||
r = []
|
||||
# unpack parameter and calculate y values and concatenate all
|
||||
for values, p_idx in zip(data, used_pars):
|
||||
actual_parameters = [self.parameter[keys].value for keys in p_idx]
|
||||
r = np.r_[r, values.cost(actual_parameters)]
|
||||
|
||||
return r
|
||||
|
||||
def __cost_odr_glob(self, p, data, varpars, used_pars):
|
||||
# replace values
|
||||
for keys, values in zip(varpars, p):
|
||||
self.parameter[keys].scaled_value = values
|
||||
|
||||
r = []
|
||||
# unpack parameter and calculate y values and concatenate all
|
||||
for values, p_idx in zip(data, used_pars):
|
||||
actual_parameters = [self.parameter[keys].value for keys in p_idx]
|
||||
r = np.r_[r, values.func(actual_parameters, values.x)]
|
||||
|
||||
return r
|
||||
Parameters.reset()
|
||||
|
||||
def _least_squares_single(self, data, p0, lb, ub, var):
|
||||
self.step = 0
|
||||
@ -322,7 +338,7 @@ class FitRoutine(object):
|
||||
if self._abort:
|
||||
raise FitAbortException(f'Fit aborted by user')
|
||||
|
||||
return self.__cost_scipy(p, data, var, data.para_keys)
|
||||
return _cost_scipy(p, data, var, data.para_keys)
|
||||
|
||||
with np.errstate(all='ignore'):
|
||||
res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0))
|
||||
@ -336,7 +352,7 @@ class FitRoutine(object):
|
||||
self.step += 1
|
||||
if self._abort:
|
||||
raise FitAbortException(f'Fit aborted by user')
|
||||
return self.__cost_scipy_glob(p, data, var, data_pars)
|
||||
return _cost_scipy_glob(p, data, var, data_pars)
|
||||
|
||||
with np.errstate(all='ignore'):
|
||||
res = optimize.least_squares(cost, p0, bounds=(lb, ub), max_nfev=500 * len(p0))
|
||||
@ -351,7 +367,7 @@ class FitRoutine(object):
|
||||
self.step += 1
|
||||
if self._abort:
|
||||
raise FitAbortException(f'Fit aborted by user')
|
||||
return (self.__cost_scipy(p, data, var, data.para_keys)**2).sum()
|
||||
return (_cost_scipy(p, data, var, data.para_keys) ** 2).sum()
|
||||
|
||||
with np.errstate(all='ignore'):
|
||||
res = optimize.minimize(cost, p0, bounds=[(b1, b2) for (b1, b2) in zip(lb, ub)],
|
||||
@ -364,7 +380,7 @@ class FitRoutine(object):
|
||||
self.step += 1
|
||||
if self._abort:
|
||||
raise FitAbortException(f'Fit aborted by user')
|
||||
return (self.__cost_scipy_glob(p, data, var, data_pars)**2).sum()
|
||||
return (_cost_scipy_glob(p, data, var, data_pars) ** 2).sum()
|
||||
|
||||
with np.errstate(all='ignore'):
|
||||
res = optimize.minimize(cost, p0, bounds=[(b1, b2) for (b1, b2) in zip(lb, ub)],
|
||||
@ -380,13 +396,18 @@ class FitRoutine(object):
|
||||
self.step += 1
|
||||
if self._abort:
|
||||
raise FitAbortException(f'Fit aborted by user')
|
||||
return self.__cost_odr(p, data, var_pars, data.para_keys)
|
||||
return _cost_odr(p, data, var_pars, data.para_keys)
|
||||
|
||||
odr_model = odr.Model(func)
|
||||
|
||||
corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0)
|
||||
|
||||
self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)),
|
||||
err=res.sd_beta, corr=corr, partial_corr=partial_corr)
|
||||
|
||||
def _odr_fit(self, odr_data, odr_model, p0):
|
||||
o = odr.ODR(odr_data, odr_model, beta0=p0)
|
||||
res = o.run()
|
||||
|
||||
corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var
|
||||
try:
|
||||
corr_inv = np.linalg.inv(corr)
|
||||
@ -395,16 +416,14 @@ class FitRoutine(object):
|
||||
partial_corr[np.diag_indices_from(partial_corr)] = 1.
|
||||
except np.linalg.LinAlgError:
|
||||
partial_corr = corr
|
||||
|
||||
self.make_results(data, res.beta, var_pars, data.para_keys, (len(data), len(p0)),
|
||||
err=res.sd_beta, corr=corr, partial_corr=partial_corr)
|
||||
return corr, partial_corr, res
|
||||
|
||||
def _odr_global(self, data, p0, var, data_pars):
|
||||
def func(p, _):
|
||||
self.step += 1
|
||||
if self._abort:
|
||||
raise FitAbortException(f'Fit aborted by user')
|
||||
return self.__cost_odr_glob(p, data, var, data_pars)
|
||||
return _cost_odr_glob(p, data, var, data_pars)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
@ -415,17 +434,7 @@ class FitRoutine(object):
|
||||
odr_data = odr.Data(x, y)
|
||||
odr_model = odr.Model(func)
|
||||
|
||||
o = odr.ODR(odr_data, odr_model, beta0=p0, ifixb=var)
|
||||
res = o.run()
|
||||
|
||||
corr = res.cov_beta / (res.sd_beta[:, None] * res.sd_beta[None, :]) * res.res_var
|
||||
try:
|
||||
corr_inv = np.linalg.inv(corr)
|
||||
corr_inv_diag = np.diag(np.sqrt(1 / np.diag(corr_inv)))
|
||||
partial_corr = -1. * np.dot(np.dot(corr_inv_diag, corr_inv), corr_inv_diag) # Partial correlation matrix
|
||||
partial_corr[np.diag_indices_from(partial_corr)] = 1.
|
||||
except np.linalg.LinAlgError:
|
||||
partial_corr = corr
|
||||
corr, partial_corr, res = self._odr_fit(odr_data, odr_model, p0)
|
||||
|
||||
for v, var_pars_k in zip(data, data_pars):
|
||||
self.make_results(v, res.beta, var, var_pars_k, (sum(len(d) for d in data), len(p0)),
|
||||
@ -439,15 +448,17 @@ class FitRoutine(object):
|
||||
|
||||
# update parameter values
|
||||
for keys, p_value, err_value in zip(var_pars, p, err):
|
||||
self.parameter[keys].scaled_value = p_value
|
||||
self.parameter[keys].scaled_error = err_value
|
||||
if keys in data.parameter.keys():
|
||||
data.parameter[keys].scaled_value = p_value
|
||||
data.parameter[keys].scaled_error = err_value
|
||||
data.parameter[keys].namespace[keys] = data.parameter[keys].value
|
||||
|
||||
combinations = list(product(var_pars, var_pars))
|
||||
actual_parameters = []
|
||||
corr_idx = []
|
||||
|
||||
for i, p_i in enumerate(used_pars):
|
||||
actual_parameters.append(self.parameter[p_i])
|
||||
actual_parameters.append(data.parameter[p_i])
|
||||
for j, p_j in enumerate(used_pars):
|
||||
try:
|
||||
# find the position of the parameter combinations
|
||||
@ -508,3 +519,4 @@ class FitRoutine(object):
|
||||
partial_corr = corr
|
||||
|
||||
return _err, corr, partial_corr
|
||||
|
||||
|
@ -6,7 +6,7 @@ from typing import Sized
|
||||
from numpy import inf
|
||||
|
||||
from ._meta import MultiModel
|
||||
from .parameter import Parameters
|
||||
from .parameter import Parameters, Parameter
|
||||
|
||||
|
||||
class Model(object):
|
||||
@ -25,7 +25,6 @@ class Model(object):
|
||||
self.ub = [i if i is not None else inf for i in self.ub]
|
||||
|
||||
self.parameter = Parameters()
|
||||
self.global_parameter = {}
|
||||
self.is_complex = None
|
||||
self._complex_part = False
|
||||
|
||||
@ -80,23 +79,33 @@ class Model(object):
|
||||
self.fun_kwargs = {k: v.default for k, v in inspect.signature(model.func).parameters.items()
|
||||
if v.default is not inspect.Parameter.empty}
|
||||
|
||||
def set_global_parameter(self, idx, p, var=None, lb=None, ub=None, default_bounds=False):
|
||||
if idx is None:
|
||||
self.parameter = Parameters()
|
||||
self.global_parameter = {}
|
||||
return
|
||||
def set_global_parameter(self,
|
||||
key: str | Parameter,
|
||||
value: float | str = None,
|
||||
*,
|
||||
var: bool = None,
|
||||
lb: float = None,
|
||||
ub: float = None,
|
||||
default_bounds: bool = False,
|
||||
) -> Parameter:
|
||||
|
||||
if isinstance(key, Parameter):
|
||||
p = key
|
||||
key = f'p{next(Parameters.parameter_counter)}'
|
||||
self.parameter.add_parameter(key, p)
|
||||
|
||||
else:
|
||||
idx = [self.params.index(key)]
|
||||
if default_bounds:
|
||||
if lb is None:
|
||||
lb = [self.lb[i] for i in idx]
|
||||
if ub is None:
|
||||
ub = [self.lb[i] for i in idx]
|
||||
|
||||
gp = self.parameter.add_parameter(p, var=var, lb=lb, ub=ub)
|
||||
for k, v in zip(idx, gp):
|
||||
self.global_parameter[k] = v
|
||||
p = self.parameter.add(key, value, var=var, lb=lb, ub=ub)
|
||||
p.is_global = True
|
||||
|
||||
return gp
|
||||
return p
|
||||
|
||||
@staticmethod
|
||||
def _prep(param_len, val):
|
||||
|
@ -1,94 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from numbers import Number
|
||||
import re
|
||||
from itertools import count
|
||||
|
||||
from io import StringIO
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Parameters(dict):
|
||||
count = count()
|
||||
parameter_counter = count()
|
||||
# is one global namespace a good idea?
|
||||
namespace: dict = {}
|
||||
|
||||
def __str__(self):
|
||||
return 'Parameters:\n' + '\n'.join([str(k)+': '+str(v) for k, v in self.items()])
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._mapping: dict = {}
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, (list, tuple, np.ndarray)):
|
||||
values = []
|
||||
for item_i in item:
|
||||
values.append(super().__getitem__(item_i))
|
||||
return values
|
||||
def __str__(self) -> str:
|
||||
return 'Parameters:\n' + '\n'.join([f'{k}: {v}' for k, v in self.items()])
|
||||
|
||||
def __getitem__(self, item) -> Parameter:
|
||||
if item in self._mapping:
|
||||
return super().__getitem__(self._mapping[item])
|
||||
else:
|
||||
return super().__getitem__(item)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.add_parameter(key, value)
|
||||
|
||||
def __contains__(self, item):
|
||||
for v in self.values():
|
||||
if item == v.name:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def add(self,
|
||||
name: str,
|
||||
value: str | float | int = None,
|
||||
*,
|
||||
var: bool = True,
|
||||
lb: float = -np.inf, ub: float = np.inf) -> Parameter:
|
||||
|
||||
par = Parameter(name=name, value=value, var=var, lb=lb, ub=ub)
|
||||
key = f'p{next(Parameters.parameter_counter)}'
|
||||
|
||||
self.add_parameter(key, par)
|
||||
|
||||
return par
|
||||
|
||||
def add_parameter(self, key: str, parameter: Parameter):
|
||||
self._mapping[parameter.name] = key
|
||||
super().__setitem__(key, parameter)
|
||||
|
||||
parameter.eval_allowed = False
|
||||
self.namespace[key] = parameter.value
|
||||
parameter.namespace = self.namespace
|
||||
parameter.eval_allowed = True
|
||||
|
||||
self.update_namespace()
|
||||
|
||||
def replace_parameter(self, key_out: str, key_in: str, parameter: Parameter):
|
||||
self.add_parameter(key_in, parameter)
|
||||
for k, v in self._mapping.items():
|
||||
if v == key_out:
|
||||
self._mapping[k] = key_in
|
||||
break
|
||||
|
||||
if key_out in self.namespace:
|
||||
del self.namespace[key_out]
|
||||
|
||||
def fix(self):
|
||||
for v in self.keys():
|
||||
v._value = v.value
|
||||
v.namespace = {}
|
||||
|
||||
@staticmethod
|
||||
def _prep_bounds(val, p_len: int) -> list:
|
||||
# helper function to ensure that bounds and variable are of parameter shape
|
||||
if isinstance(val, (Number, bool)) or val is None:
|
||||
return [val] * p_len
|
||||
def reset():
|
||||
Parameters.namespace = {}
|
||||
|
||||
elif len(val) == p_len:
|
||||
return val
|
||||
|
||||
elif len(val) == 1:
|
||||
return [val[0]] * p_len
|
||||
|
||||
else:
|
||||
raise ValueError('Input {} has wrong dimensions'.format(val))
|
||||
|
||||
def add_parameter(self, param, var=None, lb=None, ub=None):
|
||||
if isinstance(param, Number):
|
||||
param = [param]
|
||||
|
||||
p_len = len(param)
|
||||
|
||||
# make list if only single value is given
|
||||
var = self._prep_bounds(var, p_len)
|
||||
lb = self._prep_bounds(lb, p_len)
|
||||
ub = self._prep_bounds(ub, p_len)
|
||||
|
||||
new_keys = []
|
||||
for i in range(p_len):
|
||||
new_idx = next(self.count)
|
||||
new_keys.append(new_idx)
|
||||
|
||||
self[new_idx] = Parameter(param[i], var=var[i], lb=lb[i], ub=ub[i])
|
||||
|
||||
return new_keys
|
||||
|
||||
def copy(self):
|
||||
p = Parameters()
|
||||
def get_key(self, name: str) -> str | None:
|
||||
for k, v in self.items():
|
||||
p[k] = Parameter(v.value, var=v.var, lb=v.lb, ub=v.ub)
|
||||
if name == v.name:
|
||||
return k
|
||||
|
||||
if len(p) == 0:
|
||||
return p
|
||||
|
||||
max_k = max(p.keys())
|
||||
c = next(p.count)
|
||||
while c < max_k:
|
||||
c = next(p.count)
|
||||
|
||||
return p
|
||||
return
|
||||
|
||||
def get_state(self):
|
||||
return {k: v.get_state() for k, v in self.items()}
|
||||
|
||||
def update_namespace(self):
|
||||
for p in self.values():
|
||||
try:
|
||||
p.value
|
||||
except NameError:
|
||||
expression = p._expr_disp
|
||||
for n, k in self._mapping.items():
|
||||
expression, num_replaced = re.subn(re.escape(n), k, expression)
|
||||
|
||||
p._expr = expression
|
||||
|
||||
|
||||
class Parameter:
|
||||
"""
|
||||
Container for one parameter
|
||||
"""
|
||||
__slots__ = ['name', 'value', 'error', 'init_val', 'var', 'lb', 'ub', 'scale', 'function']
|
||||
|
||||
def __init__(self, value: float, var: bool = True, lb: float = -np.inf, ub: float = np.inf):
|
||||
self.lb = lb if lb is not None else -np.inf
|
||||
self.ub = ub if ub is not None else np.inf
|
||||
# TODO Parameter should know its own key
|
||||
def __init__(self, name: str, value: float | str, var: bool = True, lb: float = -np.inf, ub: float = np.inf):
|
||||
self._value: float | None = None
|
||||
self.var: bool = bool(var) if var is not None else True
|
||||
self.error: None | float = None if self.var is False else 0.0
|
||||
self.name: str = name
|
||||
self.function: str = ""
|
||||
|
||||
if self.lb <= value <= self.ub:
|
||||
self.value = value
|
||||
self.lb: None | float = lb if lb is not None else -np.inf
|
||||
self.ub: float | None = ub if ub is not None else np.inf
|
||||
self.namespace: dict = {}
|
||||
self.eval_allowed: bool = True
|
||||
self._expr: None | str = None
|
||||
self._expr_disp: None | str = None
|
||||
self.is_global = False
|
||||
|
||||
if isinstance(value, str):
|
||||
self._expr_disp = value
|
||||
self._expr = value
|
||||
self.var = False
|
||||
else:
|
||||
if self.lb <= value <= self.ub:
|
||||
self._value = value
|
||||
else:
|
||||
print(value, self.lb, self.ub)
|
||||
raise ValueError('Value of parameter is outside bounds')
|
||||
|
||||
self.init_val = value
|
||||
@ -100,25 +140,31 @@ class Parameter:
|
||||
if self.scale == 0:
|
||||
self.scale = 1.
|
||||
|
||||
self.var = bool(var) if var is not None else True
|
||||
self.error = None if self.var is False else 0.0
|
||||
self.name = ''
|
||||
self.function = ''
|
||||
|
||||
def __str__(self):
|
||||
start = ''
|
||||
def __str__(self) -> str:
|
||||
start = StringIO()
|
||||
if self.name:
|
||||
if self.function:
|
||||
start = f'{self.name} ({self.function}): '
|
||||
start.write(f"{self.name} ({self.function})")
|
||||
else:
|
||||
start = self.name + ': '
|
||||
start.write(self.name)
|
||||
|
||||
if self.is_global:
|
||||
start.write("*")
|
||||
|
||||
start.write(": ")
|
||||
|
||||
if self.var:
|
||||
return start + f'{self.value:.4g} +/- {self.error:.4g}, init={self.init_val}'
|
||||
start.write(f"{self.value:.4g} +/- {self.error:.4g}, init={self.init_val}")
|
||||
else:
|
||||
return start + f'{self.value:} (fixed)'
|
||||
start.write(f"{self.value:.4g}")
|
||||
if self._expr is None:
|
||||
start.write(" (fixed)")
|
||||
else:
|
||||
start.write(f" (calc: {self._expr_disp})")
|
||||
|
||||
def __add__(self, other: Parameter | float) -> float:
|
||||
return start.getvalue()
|
||||
|
||||
def __add__(self, other: Parameter | float | int) -> float:
|
||||
if isinstance(other, (float, int)):
|
||||
return self.value + other
|
||||
elif isinstance(other, Parameter):
|
||||
@ -128,30 +174,39 @@ class Parameter:
|
||||
return self.__add__(other)
|
||||
|
||||
@property
|
||||
def scaled_value(self):
|
||||
def scaled_value(self) -> float:
|
||||
return self.value / self.scale
|
||||
|
||||
@scaled_value.setter
|
||||
def scaled_value(self, value):
|
||||
self.value = value * self.scale
|
||||
def scaled_value(self, value: float) -> None:
|
||||
self._value = value * self.scale
|
||||
|
||||
@property
|
||||
def scaled_error(self):
|
||||
if self.error is None:
|
||||
return self.error
|
||||
else:
|
||||
def value(self) -> float | None:
|
||||
if self._value is not None:
|
||||
return self._value
|
||||
|
||||
if self._expr is not None and self.eval_allowed:
|
||||
return eval(self._expr, {}, self.namespace)
|
||||
|
||||
return
|
||||
|
||||
@property
|
||||
def scaled_error(self) -> None | float:
|
||||
if self.error is not None:
|
||||
return self.error / self.scale
|
||||
|
||||
return
|
||||
|
||||
@scaled_error.setter
|
||||
def scaled_error(self, value):
|
||||
def scaled_error(self, value) -> None:
|
||||
self.error = value * self.scale
|
||||
|
||||
def get_state(self):
|
||||
|
||||
def get_state(self) -> dict:
|
||||
return {slot: getattr(self, slot) for slot in self.__slots__}
|
||||
|
||||
@staticmethod
|
||||
def set_state(state: dict):
|
||||
def set_state(state: dict) -> Parameter:
|
||||
par = Parameter(state.pop('value'))
|
||||
for k, v in state.items():
|
||||
setattr(par, k, v)
|
||||
@ -159,9 +214,28 @@ class Parameter:
|
||||
return par
|
||||
|
||||
@property
|
||||
def full_name(self):
|
||||
def full_name(self) -> str:
|
||||
name = self.name
|
||||
if self.function:
|
||||
name += ' (' + self.function + ')'
|
||||
name += f" ({self.function})"
|
||||
|
||||
return name
|
||||
|
||||
def copy(self) -> Parameter:
|
||||
if self._expr:
|
||||
val = self._expr_disp
|
||||
else:
|
||||
val = self._value
|
||||
para_copy = Parameter(name=self.name, value=val, var=self.var, lb=self.lb, ub=self.ub)
|
||||
para_copy._expr = self._expr
|
||||
para_copy.namespace = self.namespace
|
||||
para_copy.is_global = self.is_global
|
||||
para_copy.error = self.error
|
||||
para_copy.function = self.function
|
||||
|
||||
return para_copy
|
||||
|
||||
def fix(self):
|
||||
self._value = self.value
|
||||
self.namespace = {}
|
||||
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -186,7 +187,7 @@ class FitResult(Points):
|
||||
nice_name = m.group(1)
|
||||
if func_number in split_funcs:
|
||||
nice_func = split_funcs[func_number]
|
||||
|
||||
pvalue.fix()
|
||||
pvalue.name = nice_name
|
||||
pvalue.function = nice_func
|
||||
parameter_dic[pname] = pvalue
|
||||
@ -223,27 +224,30 @@ class FitResult(Points):
|
||||
return self.nobs-self.nvar
|
||||
|
||||
def pprint(self, statistics=True, correlations=True):
|
||||
print('Fit result:')
|
||||
print(' model :', self.name)
|
||||
print(' #data :', self.nobs)
|
||||
print(' #var :', self.nvar)
|
||||
print('\nParameter')
|
||||
print(self.parameter_string())
|
||||
s = StringIO()
|
||||
s.write('Fit result:\n')
|
||||
s.write(f' model : {self.name}\n')
|
||||
s.write(f' #data : {self.nobs}\n')
|
||||
s.write(f' #var : {self.nvar}\n')
|
||||
s.write('\nParameter\n')
|
||||
s.write(self.parameter_string())
|
||||
|
||||
if statistics:
|
||||
print('Statistics')
|
||||
s.write('\nStatistics\n')
|
||||
for k, v in self.statistics.items():
|
||||
print(f' {k} : {v:.4f}')
|
||||
s.write(f' {k} : {v:.4f}\n')
|
||||
|
||||
if correlations and self.correlation is not None:
|
||||
print('\nCorrelation (partial corr.)')
|
||||
print(self._correlation_string())
|
||||
print()
|
||||
s.write('\nCorrelation (partial corr.)\n')
|
||||
s.write(self._correlation_string())
|
||||
s.write('\n')
|
||||
|
||||
print(s.getvalue())
|
||||
|
||||
def parameter_string(self):
|
||||
ret_val = ''
|
||||
|
||||
for pval in self.parameter.values():
|
||||
for pkey, pval in self.parameter.items():
|
||||
ret_val += convert(str(pval), old='tex', new='str') + '\n'
|
||||
|
||||
if self.fun_kwargs:
|
||||
@ -255,9 +259,7 @@ class FitResult(Points):
|
||||
def _correlation_string(self):
|
||||
ret_val = ''
|
||||
for p_i, p_j, corr_ij, pcorr_ij in self.correlation_list():
|
||||
ret_val += ' {} / {} : {:.4f} ({:.4f})\n'.format(convert(p_i, old='tex', new='str'),
|
||||
convert(p_j, old='tex', new='str'),
|
||||
corr_ij, pcorr_ij)
|
||||
ret_val += f" {convert(p_i, old='tex', new='str')} / {convert(p_j, old='tex', new='str')} : {corr_ij:.4f} ({pcorr_ij:.4f})\n"
|
||||
return ret_val
|
||||
|
||||
def correlation_list(self, limit=0.1):
|
||||
|
@ -35,8 +35,8 @@ class Gaussian:
|
||||
class Lorentzian:
|
||||
type = 'Spectrum'
|
||||
name = 'Lorentzian'
|
||||
equation = 'A (2/\pi)w/[4*(x-\mu)^{2} + w^{2}] + A_{0}'
|
||||
params = ['A', '\mu', 'w', 'A_{0}']
|
||||
equation = r'A (2/\pi)w/[4*(x-\mu)^{2} + w^{2}] + A_{0}'
|
||||
params = ['A', r'\mu', 'w', 'A_{0}']
|
||||
ext_params = None
|
||||
bounds = [(0, None), (None, None), (0, None), (None, None)]
|
||||
|
||||
@ -62,9 +62,9 @@ class Lorentzian:
|
||||
class PseudoVoigt:
|
||||
type = 'Spectrum'
|
||||
name = 'Pseudo Voigt'
|
||||
equation = 'A [R*2/\pi*w/[4*(x-\mu)^{2} + w^{2}] + ' \
|
||||
'(1-R)*sqrt(4*ln(2)/pi)/w*exp(-4*ln(2)[(x-\mu)/w]^{2})] + A_{0}'
|
||||
params = ['A', 'R', '\mu', 'w', 'A_{0}']
|
||||
equation = r'A [R*2/\pi*w/[4*(x-\mu)^{2} + w^{2}] + ' \
|
||||
r'(1-R)*sqrt(4*ln(2)/pi)/w*exp(-4*ln(2)[(x-\mu)/w]^{2})] + A_{0}'
|
||||
params = ['A', 'R', r'\mu', 'w', 'A_{0}']
|
||||
ext_params = None
|
||||
bounds = [(0, None), (0, 1), (None, None), (0, None)]
|
||||
|
||||
|
@ -7,7 +7,7 @@
|
||||
<x>0</x>
|
||||
<y>0</y>
|
||||
<width>365</width>
|
||||
<height>78</height>
|
||||
<height>66</height>
|
||||
</rect>
|
||||
</property>
|
||||
<property name="sizePolicy">
|
||||
@ -62,7 +62,7 @@
|
||||
<item>
|
||||
<widget class="LineEdit" name="parameter_line">
|
||||
<property name="sizePolicy">
|
||||
<sizepolicy hsizetype="Fixed" vsizetype="Fixed">
|
||||
<sizepolicy hsizetype="MinimumExpanding" vsizetype="Fixed">
|
||||
<horstretch>0</horstretch>
|
||||
<verstretch>0</verstretch>
|
||||
</sizepolicy>
|
||||
@ -78,19 +78,6 @@
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
<item>
|
||||
<spacer name="horizontalSpacer">
|
||||
<property name="orientation">
|
||||
<enum>Qt::Horizontal</enum>
|
||||
</property>
|
||||
<property name="sizeHint" stdset="0">
|
||||
<size>
|
||||
<width>40</width>
|
||||
<height>20</height>
|
||||
</size>
|
||||
</property>
|
||||
</spacer>
|
||||
</item>
|
||||
<item>
|
||||
<widget class="QCheckBox" name="fixed_check">
|
||||
<property name="text">
|
||||
@ -105,19 +92,6 @@
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
<item>
|
||||
<widget class="QToolButton" name="toolButton">
|
||||
<property name="text">
|
||||
<string/>
|
||||
</property>
|
||||
<property name="popupMode">
|
||||
<enum>QToolButton::InstantPopup</enum>
|
||||
</property>
|
||||
<property name="arrowType">
|
||||
<enum>Qt::RightArrow</enum>
|
||||
</property>
|
||||
</widget>
|
||||
</item>
|
||||
</layout>
|
||||
</item>
|
||||
<item>
|
||||
|
Loading…
Reference in New Issue
Block a user