1
0
forked from IPKM/nmreval
nmreval/nmreval/gui_qt/fit/function_creation_dialog.py
2022-03-24 20:24:28 +01:00

239 lines
8.1 KiB
Python

import re
import numexpr as ne
import numpy as np
from ..Qt import QtCore, QtWidgets
from .._py.fitcreationdialog import Ui_Dialog
# (compiled pattern, replacement) pairs that turn numexpr function calls such
# as ``sin(`` into NumPy calls (``np.sin(``); used by QUserFitCreator.save.
# The negative lookbehind stops a name from matching inside a longer
# identifier or an already qualified call. Without it, ``sin\(`` also
# rewrote ``arcsin(`` (and an already converted ``np.arcsin(``) into
# ``arcnp.sin(``.
_numexpr_funcs = [
    (re.compile(r'(?<![\w.])' + re.escape(func_name) + r'\('), 'np.' + func_name + '(')
    for func_name in ne.expressions.functions
]
class QUserFitCreator(QtWidgets.QDialog, Ui_Dialog):
    """Dialog to assemble a user-defined fit function.

    The widgets are created by ``Ui_Dialog.setupUi``. :meth:`register` builds
    a class from the values collected by :meth:`check` and emits it through
    ``classCreated``.
    """

    classCreated = QtCore.pyqtSignal(object)

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self.setupUi(self)

        self.namespace_widget.make_namespace()

        self.tableWidget.itemChanged.connect(self.update_function)
        # the four group boxes behave like mutually exclusive sections,
        # see change_visibility
        for box in (self.groupBox, self.groupBox_2, self.groupBox_3, self.groupBox_4):
            box.toggled.connect(self.change_visibility)

        self.groupBox.setChecked(True)

    def __call__(self, *args, **kwargs):
        """Reset the dialog by clearing its text inputs."""
        for w in (self.lineEdit_4, self.lineEdit, self.lineEdit_3, self.lineEdit_2,
                  self.parameterLineEdit, self.externalParametersLineEdit):
            w.clear()

    def check(self):
        """Validate the entered fit definition.

        Reads the name, group, equation, parameters, external parameters and
        function expression from the inputs and stores them on ``self``
        (``name``, ``group``, ``eq``, ``p``, ``ext_p``, ``func``,
        ``_func_string``). The expression is test-evaluated with numexpr
        after rewriting ``p[i]`` to ``p_i``.

        Returns:
            True if all entries are valid. Otherwise a warning box listing the
            problems is shown and False is returned.
        """
        self.name = self.name_lineedit.text()
        self.group = self.group_lineedit.text()
        self.eq = str(self.lineEdit.text())
        self.p = str(self.parameterLineEdit.text()).split()
        # ``ext_p`` is used below but was never assigned anywhere in this
        # class, so any non-empty parameter list raised AttributeError.
        # Read it from the external-parameter input, the same way ``p`` is read.
        self.ext_p = str(self.externalParametersLineEdit.text()).split()
        self.func = str(self.lineEdit_4.text())
        self._func_string = ''

        error = []
        for k, v in [('Name', self.name), ('Group', self.group), ('Parameters', self.p), ('Function', self.func)]:
            if not v:
                error.append('Empty ' + str(k))

        if self.name and self.name[0].isdigit():
            error.append('Name starts with digit')

        duplicates = set(self.p) & set(self.ext_p)
        if duplicates:
            error.append('Duplicate entries: {}'.format(list(duplicates)))

        if self.p and self.func:
            # dummy values: one for every (regular + external) parameter
            namespace = {'x': np.arange(2)}
            p_test = np.ones(len(self.p) + len(self.ext_p))
            namespace.update({f'p_{i}': pp for i, pp in enumerate(p_test)})

            # numexpr cannot index arrays, so ``p[0]`` is evaluated as the scalar name ``p_0``
            self._func_string = self.func.replace('[', '_').replace(']', '')
            try:
                ne.evaluate(self._func_string, local_dict=namespace)
            except Exception:
                # unknown names surface as KeyError, but malformed input raises
                # other types (e.g. SyntaxError); all are user-input errors and
                # must be reported instead of escaping the dialog
                error.append(f'Incorrect evaluation {self.func}')

        if error:
            QtWidgets.QMessageBox.warning(self, 'Invalid entries', '\n'.join(error))
            return False

        return True

    def accept(self):
        """Print the class skeleton (see :meth:`confirm`) and close the dialog."""
        self.confirm()
        super().accept()

    def confirm(self):
        """Print a source skeleton of the user fit to stdout.

        The output contains the ``name``, ``group`` and ``params`` entries,
        the signature from ``self.label``, every entry of the namespace widget
        (for functions: their module and qualified name), and finally the text
        of ``self.plainTextEdit``.
        """
        print(f' name = {self.name_lineedit.text()}')
        group_type = self.group_lineedit.text()
        if group_type:
            print(f' group = "{group_type}"')
        else:
            print(' group = "User-defined"')

        # column 1 of the parameter table holds the entries written to ``params``
        var = [self.tableWidget.item(row, 1).text() for row in range(self.tableWidget.rowCount())]
        if var:
            print(' params = [r"', end='')
            print('", r"'.join(var) + '"]')
        else:
            print(' params = []')

        print('\n@staticmethod')
        print(self.label.text())

        import inspect
        for k, v in self.namespace_widget.namespace.flatten().items():
            if inspect.isfunction(v):
                print(k, inspect.getmodule(v))
                print(k, [v.__qualname__])
            else:
                print(k, v)

        print(self.plainTextEdit.toPlainText())

    @QtCore.pyqtSlot(name='on_parameter_button_clicked')
    def add_variable(self):
        """Append a parameter row to the table and refresh the signature label.

        The new row is filled with ``p<row>``, ``p_{<row>}`` and two ``--``
        placeholders. Table signals are blocked while filling so that
        ``itemChanged`` does not fire for every cell.
        """
        self.tableWidget.blockSignals(True)
        row = self.tableWidget.rowCount()
        self.tableWidget.setRowCount(row+1)
        self.tableWidget.setItem(row, 0, QtWidgets.QTableWidgetItem('p'+str(row)))
        self.tableWidget.setItem(row, 1, QtWidgets.QTableWidgetItem('p_{'+str(row)+'}'))
        self.tableWidget.setItem(row, 2, QtWidgets.QTableWidgetItem('--'))
        self.tableWidget.setItem(row, 3, QtWidgets.QTableWidgetItem('--'))
        self.tableWidget.blockSignals(False)

        self.update_function(None)

    @QtCore.pyqtSlot(name='on_selection_button_clicked')
    def add_choice(self):
        """Add a new ChoiceWidget tab and make it the current tab."""
        cnt = self.tabWidget.count()
        self.tabWidget.addTab(ChoiceWidget(self), 'choice' + str(cnt))
        self.tabWidget.setCurrentIndex(cnt)

    def register(self):
        """Create a fit class from the values stored by :meth:`check`.

        The class is emitted through ``classCreated``.
        NOTE: class names are not checked for uniqueness.

        Returns:
            tuple: the class name (``self.name`` without spaces) and the class.
        """
        classname = self.name.replace(' ', '')
        c = register_class(classname, self.name, self.group, self.p, self.eq, self._func_string)
        self.classCreated.emit(c)

        return classname, c

    def save(self, cname):
        """Render Python source for a fit class named *cname*.

        Uses the values stored by :meth:`check`. numexpr function calls in the
        expression are rewritten to NumPy calls (``np.<name>(``), so the
        generated code expects ``np`` to be imported where it is used.
        Previously the source was built and then discarded. It is now returned;
        persisting it is left to the caller.

        Returns:
            str: the generated class source.
        """
        # body lines indented so that the rendered text is valid Python
        template = (
            '\n# Created automatically\n'
            'class {cname:}(object):\n'
            '    name = "{name:}"\n'
            '    type = "{group:}"\n'
            '    equation = "{eq:}"\n'
            '    params = {p:}\n'
            '    ext_params = {ep:}\n\n'
            '    @staticmethod\n'
            '    def func(p, x):\n'
            '        return {func:}\n'
        )

        f_string = self.func
        for pat, repl in _numexpr_funcs:
            f_string = pat.sub(repl, f_string)

        return template.format(cname=cname, name=self.name, group=self.group, eq=self.eq,
                               p=self.p, ep=self.ext_p, func=f_string)

    @QtCore.pyqtSlot(QtWidgets.QTableWidgetItem)
    @QtCore.pyqtSlot(str)
    def update_function(self, _):
        """Rebuild the ``def func(x, ...)`` signature shown in ``self.label``.

        Argument names come from column 0 of the parameter table. A
        ``nucleus=2.67522128e8`` keyword is appended when ``use_nuclei`` is
        checked. The slot argument is ignored.
        """
        var = [self.tableWidget.item(row, 0).text() for row in range(self.tableWidget.rowCount())]

        if self.use_nuclei.isChecked():
            # presumably the proton gyromagnetic ratio as default -- TODO confirm
            var.append('nucleus=2.67522128e8')

        self.label.setText('def func(x, ' + ', '.join(var) + '):')

    def change_visibility(self):
        """Check only the sending group box and show only its associated widget."""
        sender = self.sender()
        for gb in [self.groupBox, self.groupBox_2, self.groupBox_3, self.groupBox_4]:
            # block signals so unchecking the other boxes does not re-enter this slot
            gb.blockSignals(True)
            gb.setChecked(sender == gb)
            gb.blockSignals(False)

        self.widget_2.setVisible(sender == self.groupBox)
        self.widget_3.setVisible(sender == self.groupBox_2)
        self.widget.setVisible(sender == self.groupBox_3)
        self.namespace_widget.setVisible(sender == self.groupBox_4)
class ChoiceWidget(QtWidgets.QWidget):
    """Editor page for one selectable option ("choice") of a user fit.

    It provides:

    * line edits for the internal and the displayed name,
    * a type selector ('str', 'int', 'float'),
    * an "Add option" button,
    * a two-column (Name / Value) table.
    """

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self._init_ui()

    def _init_ui(self):
        grid = QtWidgets.QGridLayout()
        grid.setContentsMargins(3, 3, 3, 3)
        grid.setHorizontalSpacing(6)

        # row 0: label / edit pairs for the name and the displayed name
        self.label = QtWidgets.QLabel('Name', parent=self)
        self.name_line = QtWidgets.QLineEdit(self)
        self.label_2 = QtWidgets.QLabel('Displayed name', parent=self)
        self.display_line = QtWidgets.QLineEdit(self)
        for column, widget in enumerate((self.label, self.name_line, self.label_2, self.display_line)):
            grid.addWidget(widget, 0, column)

        # type selector and button are added without explicit cells; the grid
        # layout chooses their position (presumably the next free cells of row 1) -- TODO confirm
        self.label_3 = QtWidgets.QLabel('Type', parent=self)
        self.types = QtWidgets.QComboBox(self)
        self.types.addItems(['str', 'int', 'float'])
        self.add_button = QtWidgets.QPushButton('Add option')
        for widget in (self.label_3, self.types, self.add_button):
            grid.addWidget(widget)

        # option table spanning all four columns of row 2
        self.table = QtWidgets.QTableWidget(self)
        self.table.setColumnCount(2)
        self.table.setHorizontalHeaderLabels(['Name', 'Value'])
        grid.addWidget(self.table, 2, 0, 1, 4)

        self.setLayout(grid)
def register_class(cname, name, group, p, eq, func):
    """Create a new fit class at runtime.

    Args:
        cname: name of the created class.
        name: value of the ``name`` class attribute.
        group: value of the ``type`` class attribute.
        p: value of the ``params`` class attribute (a list of parameter names
            when called from ``QUserFitCreator.register``).
        eq: value of the ``equation`` class attribute.
        func: numexpr expression string. It is turned into the ``func``
            callable by ``func_decorator`` and called as ``func(p, x)``.

    Returns:
        The newly created class.
    """
    c = type(cname, (), {})
    c.name = name
    c.type = group
    c.params = p
    c.equation = eq
    # Stored as a plain function, ``func`` worked when called on the class but
    # received the instance as ``p`` when called on an instance. staticmethod
    # makes both call styles pass (p, x), matching the
    # ``@staticmethod def func(p, x)`` layout of the generated class source.
    c.func = staticmethod(func_decorator(func))

    return c
def func_decorator(f_string):
    """Build a callable that evaluates the numexpr expression *f_string*.

    The returned function is called as ``f(p, x)``. ``x`` is exposed to the
    expression under the name ``x`` and the entries of ``p`` as ``p_0``,
    ``p_1``, ... before ``ne.evaluate`` runs.

    A closure is returned (rather than a method) so the result can be attached
    to a dynamically created fit class.
    """
    def wrapped_f(*args):
        params, x = args[0], args[1]
        local_vars = {'x': x}
        local_vars.update({f'p_{idx}': value for idx, value in enumerate(params)})
        return ne.evaluate(f_string, local_dict=local_vars)

    return wrapped_f