nmreval/src/gui_qt/fit/fit_parameter.py

344 lines
12 KiB
Python
Raw Normal View History

from __future__ import annotations
2022-10-20 15:23:15 +00:00
from nmreval.utils.text import convert
2022-03-08 09:27:40 +00:00
from ..Qt import QtWidgets, QtCore, QtGui
from .._py.fitfuncwidget import Ui_FormFit
from ..lib.forms import SelectionWidget
2022-03-08 09:27:40 +00:00
from .fit_forms import FitModelWidget
class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
    """Editor for the parameters of a single fit function.

    ``scrollwidget`` holds the widgets for the global values: one
    ``FitModelWidget`` per numeric parameter, followed by one
    ``SelectionWidget`` per choice of the function (``global_parameter``).
    ``scrollwidget2`` holds the matching per-dataset widgets in the same order
    (``data_parameter``); ``comboBox`` selects the dataset they display.

    Per-dataset values live in ``data_values``, which maps a dataset id to a
    list with one entry per parameter widget.  ``None`` means "no local
    value", in which case the global value from ``glob_values`` is used.
    """

    # Emitted with the position (in global_parameter) of a FitModelWidget
    # that requested a value.
    value_requested = QtCore.pyqtSignal(int)

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self.setupUi(self)

        self.func = None
        self.func_idx = None
        # size hint of the widest parameter label; all labels get this size
        self.max_width = QtCore.QSize(0, 0)
        # widgets for global values (numeric parameters first, then choices)
        self.global_parameter = []
        # widgets showing the values of the selected dataset, same order
        self.data_parameter = []
        # global values, index-aligned with global_parameter
        self.glob_values = None
        # dataset id -> list of local values (None = use the global value)
        self.data_values = {}

        self.scrollwidget.setLayout(QtWidgets.QVBoxLayout())
        self.scrollwidget2.setLayout(QtWidgets.QVBoxLayout())

    def eventFilter(self, src: QtCore.QObject, evt: QtCore.QEvent) -> bool:
        """Step through the datasets with Ctrl+Shift+Left/Right.

        Installed on every ``ParameterSingleWidget``.  On Ctrl+Shift+Right
        (Left) the value shown in *src* is stored as local value of the current
        dataset and the next (previous) dataset is selected.  All other events
        are passed to the default implementation.
        """
        # Only handle key presses: QKeyEvent is also delivered for key releases
        # (and shortcut overrides).  Reacting to a release as well would store
        # the value displayed for the *newly* selected dataset as its local
        # value and move on a second time.
        if isinstance(evt, QtGui.QKeyEvent) and evt.type() == QtCore.QEvent.KeyPress:
            if evt.key() == QtCore.Qt.Key_Right:
                direction = 1
            elif evt.key() == QtCore.Qt.Key_Left:
                direction = -1
            else:
                direction = 0

            if direction and evt.modifiers() == QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier:
                self.change_single_parameter(src.value, sender=src)
                self.select_next_preview(direction)

                return True

        return super().eventFilter(src, evt)

    def load(self, data):
        """Replace the datasets offered in ``comboBox`` and show the first one.

        Local values of datasets that were already known are kept.

        :param data: iterable of ``(set_id, name)`` pairs; *name* is displayed,
            *set_id* is stored as item data
        """
        self.comboBox.blockSignals(True)
        try:
            self.comboBox.clear()
            for sid, name in data:
                self.comboBox.addItem(name, userData=sid)
                self._make_parameter(sid)
        finally:
            self.comboBox.blockSignals(False)

        self.change_data(0)

    def set_function(self, func, idx):
        """Build the widgets for the parameters of fit function *func*.

        For every entry of ``func.params`` a global ``FitModelWidget`` and a
        per-dataset ``ParameterSingleWidget`` are created; for every entry of
        ``func.choices`` (if present and not None) a pair of ``SelectionWidget``.

        :param func: fit function; ``params`` is required, ``bounds`` and
            ``choices`` are used if available
        :param idx: index of the function, stored in ``func_idx`` and passed to
            each ``FitModelWidget``
        """
        self.func = func
        self.func_idx = idx

        # global values of numeric parameters are initialised to 1
        self.glob_values = [1] * len(func.params)

        for k, v in enumerate(func.params):
            name = convert(v)

            global_widget = FitModelWidget(label=name, parent=self.scrollwidget)
            global_widget.parameter_pos = k
            global_widget.func_idx = idx
            try:
                global_widget.set_bounds(*func.bounds[k], False)
            except (AttributeError, IndexError):
                # function has no `bounds` attribute or no entry for this parameter
                pass

            size = global_widget.parametername.sizeHint()
            if self.max_width.width() < size.width():
                self.max_width = size

            global_widget.state_changed.connect(self.make_global)
            global_widget.value_requested.connect(self.look_for_value)
            global_widget.value_changed.connect(self.change_global_parameter)

            self.global_parameter.append(global_widget)
            self.scrollwidget.layout().addWidget(global_widget)

            data_widget = ParameterSingleWidget(name=name, parent=self.scrollwidget2)
            data_widget.valueChanged.connect(self.change_single_parameter)
            # removeSingleValue carries no argument, so the slot sees value=None
            data_widget.removeSingleValue.connect(self.change_single_parameter)
            data_widget.installEventFilter(self)

            self.scrollwidget2.layout().addWidget(data_widget)
            self.data_parameter.append(data_widget)

        # give every label the size of the widest one so that rows line up
        for w1, w2 in zip(self.global_parameter, self.data_parameter):
            w1.parametername.setFixedSize(self.max_width)
            w1.checkBox.setFixedSize(self.max_width)
            w2.label.setFixedSize(self.max_width)

        choices = getattr(func, 'choices', None)
        if choices is not None:
            for c in choices:
                global_choice = SelectionWidget(*c)
                global_choice.selectionChanged.connect(self.change_global_choice)
                self.global_parameter.append(global_choice)
                self.glob_values.append(global_choice.value)
                self.scrollwidget.layout().addWidget(global_choice)

                data_choice = SelectionWidget(*c)
                data_choice.selectionChanged.connect(self.change_single_choice)
                self.data_parameter.append(data_choice)
                self.scrollwidget2.layout().addWidget(data_choice)

        # datasets loaded before this call get value lists of the right length
        for i in range(self.comboBox.count()):
            self._make_parameter(self.comboBox.itemData(i))

        self.scrollwidget.layout().addStretch(1)
        self.scrollwidget2.layout().addStretch(1)

    def set_links(self, parameter):
        """Pass *parameter* to ``add_links`` of every ``FitModelWidget``."""
        for w in self.global_parameter:
            if isinstance(w, FitModelWidget):
                w.add_links(parameter)

    @QtCore.pyqtSlot(str)
    def change_global_parameter(self, value: str, idx: int = None):
        """Store a new global value of a numeric parameter.

        The value is also shown in the per-dataset widget unless the current
        dataset has its own value for this parameter.

        :param value: new value as text; a decimal comma is accepted, as in
            ``ParameterSingleWidget.value``
        :param idx: position of the parameter; defaults to the emitting widget
        """
        if idx is None:
            idx = self.global_parameter.index(self.sender())

        new_value = float(str(value).replace(',', '.'))
        self.glob_values[idx] = new_value
        self._show_global_value(idx, new_value)

    @QtCore.pyqtSlot(str, object)
    def change_global_choice(self, _, value):
        """Store a new global selection of a choice widget and display it."""
        idx = self.global_parameter.index(self.sender())
        self.glob_values[idx] = value
        self._show_global_value(idx, value)

    def _show_global_value(self, idx: int, value) -> None:
        """Display global *value* in per-dataset widget *idx* if the current
        dataset has no local value for it."""
        local_values = self.data_values.get(self.comboBox.currentData())
        if local_values is not None and idx < len(local_values) and local_values[idx] is not None:
            return

        widget = self.data_parameter[idx]
        # programmatic update must not be reported back as a local value
        widget.blockSignals(True)
        widget.value = value
        widget.blockSignals(False)

    def change_single_parameter(self, value: float = None, sender=None):
        """Store *value* as local value of the current dataset.

        ``None`` removes the local value; the widgets are then refreshed so
        that the global value is shown again.

        :param value: new local value or ``None``
        :param sender: per-dataset widget; defaults to the emitting widget
        """
        if sender is None:
            sender = self.sender()
        idx = self.data_parameter.index(sender)
        self.data_values[self.comboBox.currentData()][idx] = value

        if value is None:
            self.change_data(self.comboBox.currentIndex())

    def change_single_choice(self, _, value, sender=None):
        """Store the selection of a per-dataset choice widget as local value."""
        if sender is None:
            sender = self.sender()
        idx = self.data_parameter.index(sender)
        self.data_values[self.comboBox.currentData()][idx] = value

    @QtCore.pyqtSlot(object)
    def look_for_value(self, sender):
        """Re-emit a value request as ``value_requested`` with the widget position."""
        self.value_requested.emit(self.global_parameter.index(sender))

    @QtCore.pyqtSlot()
    def make_global(self):
        """Enable the per-dataset widget only if the parameter is neither
        global nor linked."""
        widget = self.sender()
        idx = self.global_parameter.index(widget)
        enable = (widget.global_checkbox.checkState() == QtCore.Qt.Unchecked) and (widget.is_linked is None)
        self.data_parameter[idx].setEnabled(enable)

    def select_next_preview(self, direction):
        """Select the dataset *direction* positions away, wrapping around.

        Does nothing if no dataset is loaded.
        """
        count = self.comboBox.count()
        if count == 0:
            return

        self.comboBox.setCurrentIndex((self.comboBox.currentIndex() + direction) % count)

    @QtCore.pyqtSlot(int, name='on_comboBox_currentIndexChanged')
    def change_data(self, idx: int):
        """Show the values of the dataset at combobox position *idx*.

        Parameters with a local value show it (marked as local where the widget
        supports it); all others show the global value.
        """
        sid = self.comboBox.itemData(idx)
        # creates the entry if needed and pads lists created before all
        # widgets existed
        self._make_parameter(sid)

        for i, (widget, value) in enumerate(zip(self.data_parameter, self.data_values[sid])):
            widget.blockSignals(True)
            try:
                widget.show_as_local_parameter(value is not None)
            except AttributeError:
                # widget without local/global indicator (presumably SelectionWidget)
                pass

            widget.value = self.glob_values[i] if value is None else value
            widget.blockSignals(False)

    def _make_parameter(self, sid):
        """Ensure dataset *sid* has one local-value slot per parameter widget.

        A new dataset gets a list of ``None`` (all values global).  A list that
        is shorter than ``data_parameter`` (e.g. created by ``load`` before
        ``set_function`` added widgets) is padded with ``None``; stored values
        are kept.
        """
        values = self.data_values.setdefault(sid, [])
        missing = len(self.data_parameter) - len(values)
        if missing > 0:
            values.extend([None] * missing)

    def get_parameter(self, use_func=None):
        """Collect start values, bounds and flags for a fit.

        :param use_func: dataset ids to include; ``None`` includes every
            dataset with stored values
        :returns: ``(data_parameter, lb, ub, is_fixed, global_p, is_linked)``.
            ``data_parameter`` maps dataset id to ``(p, kw_p)``: the numeric
            values and the keyword values of the choices.  ``lb``/``ub`` are
            the bounds of the numeric parameters.  ``global_p`` collects the
            parameters marked global, or is ``None`` if no dataset is included.
        :raises ValueError: if a value lies outside its bounds
        """
        bds = []
        is_global = []
        is_fixed = []
        globs = []
        is_linked = []

        for g in self.global_parameter:
            if isinstance(g, FitModelWidget):
                p_i, bds_i, fixed_i, global_i, link_i = g.get_parameter()

                globs.append(p_i)
                bds.append(bds_i)
                is_fixed.append(fixed_i)
                is_global.append(global_i)
                is_linked.append(link_i)

        # transpose [(lb, ub), ...] into (lb, ...), (ub, ...); zip(*[]) yields
        # nothing, so a function without numeric parameters needs a default
        lb, ub = tuple(zip(*bds)) if bds else ((), ())

        data_parameter = {}
        if use_func is None:
            use_func = list(self.data_values.keys())

        global_p = None

        for sid, parameter in self.data_values.items():
            if sid not in use_func:
                continue

            if global_p is None:
                global_p = {'p': [], 'idx': [], 'var': [], 'ub': [], 'lb': []}

            p = []
            kw_p = {}
            for i, (p_i, g) in enumerate(zip(parameter, self.global_parameter)):
                if isinstance(g, FitModelWidget):
                    if (p_i is None) or is_global[i]:
                        value = globs[i]
                        # each global parameter is recorded only once
                        if is_global[i] and i not in global_p['idx']:
                            global_p['p'].append(globs[i])
                            global_p['idx'].append(i)
                            global_p['var'].append(is_fixed[i])
                            global_p['ub'].append(ub[i])
                            global_p['lb'].append(lb[i])
                    else:
                        value = p_i

                    p.append(value)
                    self._check_bounds(g, value, lb[i], ub[i])

                elif p_i is None:
                    kw_p.update(g.value)
                elif isinstance(p_i, dict):
                    kw_p.update(p_i)
                else:
                    kw_p[g.argname] = p_i

            data_parameter[sid] = (p, kw_p)

        return data_parameter, lb, ub, is_fixed, global_p, is_linked

    @staticmethod
    def _check_bounds(widget, value, lower, upper) -> None:
        """Raise ValueError if *value* is above *upper* or below *lower*.

        A comparison raising TypeError (presumably an unset bound) is skipped.
        """
        try:
            if value > upper:
                raise ValueError(f'Parameter {widget.name} is outside bounds ({lower}, {upper})')
        except TypeError:
            pass

        try:
            if value < lower:
                raise ValueError(f'Parameter {widget.name} is outside bounds ({lower}, {upper})')
        except TypeError:
            pass

    def set_parameter(self, set_id: str | None, parameter: list[float]) -> int:
        """Set values of the numeric parameters, globally or for one dataset.

        :param set_id: dataset id, or ``None`` to set the global values
        :param parameter: values in parameter order; surplus entries are
            ignored, missing entries leave the current value unchanged
        :returns: number of numeric parameters of this function
        """
        num_parameter = [g for g in self.global_parameter if not isinstance(g, SelectionWidget)]
        param_len = len(num_parameter)

        if set_id is None:
            for i, (g, val) in enumerate(zip(num_parameter, parameter)):
                g.set_parameter(val)
                self.glob_values[i] = val
        else:
            local_values = self.data_values[set_id]
            n_set = min(param_len, len(local_values))
            for i, val in zip(range(n_set), parameter):
                local_values[i] = val

        self.change_data(self.comboBox.currentIndex())

        return param_len
class ParameterSingleWidget(QtWidgets.QWidget):
    """Input for the value of one parameter of the selected dataset.

    A bold label marks a value that belongs to this dataset only; otherwise
    the global value is displayed.  The "Use global" button drops the local
    value.
    """

    # emitted with the parsed value whenever the text changes
    valueChanged = QtCore.pyqtSignal(object)
    # emitted when "Use global" is clicked
    removeSingleValue = QtCore.pyqtSignal()

    def __init__(self, name: str, parent=None):
        super().__init__(parent=parent)
        self._init_ui()

        self._name = name

        self.label.setText(convert(name))
        self.label.setToolTip('If this is bold, then this parameter is only for this data. '
                              'Otherwise, the general parameter is used and displayed.')

        self.value_line.setValidator(QtGui.QDoubleValidator())
        # `value` never returns None (unparsable text yields 0.0), so every
        # text change is forwarded
        self.value_line.textChanged.connect(lambda: self.valueChanged.emit(self.value))

        self.reset_button.clicked.connect(lambda x: self.removeSingleValue.emit())

    def _init_ui(self):
        """Create label, value line edit and reset button in one row."""
        # constructing the layout with `self` as parent installs it on this widget
        layout = QtWidgets.QHBoxLayout(self)
        layout.setContentsMargins(2, 2, 2, 2)
        layout.setSpacing(2)

        self.label = QtWidgets.QLabel(self)
        layout.addWidget(self.label)

        layout.addSpacerItem(QtWidgets.QSpacerItem(0, 0, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum))

        self.value_line = QtWidgets.QLineEdit(self)
        # textEdited fires only for user edits: typing marks the value as local
        self.value_line.textEdited.connect(lambda x: self.show_as_local_parameter(True))
        layout.addWidget(self.value_line)

        self.reset_button = QtWidgets.QToolButton(self)
        self.reset_button.setText('Use global')
        self.reset_button.clicked.connect(lambda: self.show_as_local_parameter(False))
        layout.addWidget(self.reset_button)

    @property
    def value(self) -> float:
        """Current text as float; a decimal comma is accepted, unparsable text gives 0.0."""
        try:
            return float(self.value_line.text().replace(',', '.'))
        except ValueError:
            return 0.0

    @value.setter
    def value(self, val):
        # displayed with five significant digits
        self.value_line.setText(f'{float(val):.5g}')

    def show_as_local_parameter(self, is_local: bool):
        """Show the label bold if *is_local*, plain otherwise."""
        self.label.setStyleSheet('font-weight: bold;' if is_local else '')