# Forked from IPKM/nmreval (original file: 239 lines, 8.1 KiB, Python)
import re
|
|
|
|
import numexpr as ne
|
|
import numpy as np
|
|
|
|
from ..Qt import QtCore, QtWidgets
|
|
from .._py.fitcreationdialog import Ui_Dialog
|
|
|
|
|
|
# (compiled pattern, replacement) pairs that rewrite a numexpr function call
# such as ``sin(`` into its NumPy spelling ``np.sin(``; QUserFitCreator.save()
# applies them when exporting a user function as plain Python source.
# The look-behind makes sure only whole, unprefixed names match: without it
# ``sin(`` also matched inside ``arcsin(`` (producing ``arcnp.sin(``), and an
# already-qualified ``np.sin(`` would have been prefixed a second time.
_numexpr_funcs = [
    (re.compile(r'(?<![\w.])' + re.escape(name) + r'\('), 'np.' + name + '(')
    for name in ne.expressions.functions
]
|
|
|
|
|
|
class QUserFitCreator(QtWidgets.QDialog, Ui_Dialog):
    """Dialog for composing a user-defined fit function.

    The widgets used below (``tableWidget``, ``groupBox*``, ``widget*``,
    ``namespace_widget``, the line edits, ...) are not created in this class;
    they come from ``Ui_Dialog.setupUi``.
    """

    # Emitted by register() with the newly created fit class.
    classCreated = QtCore.pyqtSignal(object)

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self.setupUi(self)

        self.namespace_widget.make_namespace()

        self.tableWidget.itemChanged.connect(self.update_function)

        for group_box in (self.groupBox, self.groupBox_2, self.groupBox_3, self.groupBox_4):
            group_box.toggled.connect(self.change_visibility)

        # Triggers change_visibility so only the first panel is shown initially.
        self.groupBox.setChecked(True)

    def __call__(self, *args, **kwargs):
        """Clear the function, equation, name/group and parameter line edits."""
        for w in [self.lineEdit_4, self.lineEdit, self.lineEdit_3, self.lineEdit_2,
                  self.parameterLineEdit, self.externalParametersLineEdit]:
            w.clear()

    def check(self):
        """Validate the dialog entries and store them on the instance.

        Sets ``self.name``, ``self.group``, ``self.eq``, ``self.p``,
        ``self.ext_p``, ``self.func`` and ``self._func_string`` (the function
        with ``p[i]`` rewritten to ``p_i`` so numexpr can evaluate it).  The
        function is test-evaluated with all parameters set to 1 and ``x = [0, 1]``.

        Returns:
            True if every entry is valid.  Otherwise a warning box listing the
            problems is shown and False is returned.
        """
        self.name = self.name_lineedit.text()
        self.group = self.group_lineedit.text()
        self.eq = str(self.lineEdit.text())
        self.p = str(self.parameterLineEdit.text()).split()
        # ext_p is used below (and by save()) but was previously never assigned.
        self.ext_p = str(self.externalParametersLineEdit.text()).split()
        self.func = str(self.lineEdit_4.text())
        self._func_string = ''

        error = []

        for label, value in [('Name', self.name), ('Group', self.group),
                             ('Parameters', self.p), ('Function', self.func)]:
            if not value:
                error.append('Empty ' + label)

        if self.name and self.name[0].isdigit():
            error.append('Name starts with digit')

        duplicates = set(self.p) & set(self.ext_p)
        if duplicates:
            error.append('Duplicate entries: {}'.format(list(duplicates)))

        if self.p and self.func:
            p_test = np.ones(len(self.p) + len(self.ext_p))
            namespace = {'x': np.arange(2)}
            for i, pp in enumerate(p_test):
                namespace[f'p_{i}'] = pp

            # numexpr cannot index arrays, so p[i] becomes the scalar name p_i.
            self._func_string = self.func.replace('[', '_').replace(']', '')
            try:
                ne.evaluate(self._func_string, local_dict=namespace)
            except Exception as err:
                # Arbitrary user input: numexpr signals problems with many
                # exception types (KeyError, SyntaxError, TypeError, ValueError, ...).
                error.append(f'Incorrect evaluation {self.func}: {err}')

        if error:
            QtWidgets.QMessageBox.warning(self, 'Invalid entries', '\n'.join(error))
            return False

        return True

    def accept(self):
        """Print the current definition via confirm(), then close the dialog."""
        self.confirm()
        super().accept()

    def confirm(self):
        """Print a source-like summary of the fit currently being built.

        Prints the name, the group (``"User-defined"`` when empty), the entries
        of table column 1, the signature label, the namespace entries and the
        body text.  Nothing is stored.
        """
        import inspect

        print(f' name = {self.name_lineedit.text()}')

        group_type = self.group_lineedit.text() or 'User-defined'
        print(f' group = "{group_type}"')

        var = [self.tableWidget.item(row, 1).text() for row in range(self.tableWidget.rowCount())]
        print(' params = [' + ', '.join(f'r"{v}"' for v in var) + ']')

        print('\n@staticmethod')
        print(self.label.text())

        for k, v in self.namespace_widget.namespace.flatten().items():
            if inspect.isfunction(v):
                print(k, inspect.getmodule(v))
                print(k, [v.__qualname__])
            else:
                print(k, v)

        print(self.plainTextEdit.toPlainText())

    @QtCore.pyqtSlot(name='on_parameter_button_clicked')
    def add_variable(self):
        """Append a parameter row and refresh the signature label.

        The new row holds ``p<row>``, ``p_{<row>}`` and two ``--`` placeholders.
        """
        # Avoid one itemChanged -> update_function call per setItem.
        self.tableWidget.blockSignals(True)
        row = self.tableWidget.rowCount()
        self.tableWidget.setRowCount(row + 1)

        self.tableWidget.setItem(row, 0, QtWidgets.QTableWidgetItem('p' + str(row)))
        self.tableWidget.setItem(row, 1, QtWidgets.QTableWidgetItem('p_{' + str(row) + '}'))
        self.tableWidget.setItem(row, 2, QtWidgets.QTableWidgetItem('--'))
        self.tableWidget.setItem(row, 3, QtWidgets.QTableWidgetItem('--'))
        self.tableWidget.blockSignals(False)

        self.update_function(None)

    @QtCore.pyqtSlot(name='on_selection_button_clicked')
    def add_choice(self):
        """Add a new ChoiceWidget tab and make it the current one."""
        cnt = self.tabWidget.count()
        self.tabWidget.addTab(ChoiceWidget(self), 'choice' + str(cnt))
        self.tabWidget.setCurrentIndex(cnt)

    def register(self):
        """Create a fit class from the values stored by check().

        Emits ``classCreated`` with the class and returns ``(classname, class)``.
        """
        # Spaces are stripped so the class name looks like an identifier.
        classname = self.name.replace(' ', '')
        # TODO: make classname unique among registered user fits; an earlier
        # sketch used a ``_userfits`` registry that is not defined in this module.
        c = register_class(classname, self.name, self.group, self.p, self.eq, self._func_string)
        self.classCreated.emit(c)

        return classname, c

    def save(self, cname):
        """Return Python source for a standalone fit class named *cname*.

        Uses the values stored by check().  numexpr function calls in the
        function string are prefixed with ``np.`` so the generated ``func(p, x)``
        runs with plain NumPy.  Nothing is written to disk here; the caller
        decides what to do with the returned text (previously the text was
        built and then discarded).
        """
        # Class body indented by 4 spaces and the function body by 8, so the
        # generated source is valid Python.
        template = (
            '\n# Created automatically\n'
            'class {cname:}(object):\n'
            '    name = "{name:}"\n'
            '    type = "{group:}"\n'
            '    equation = "{eq:}"\n'
            '    params = {p:}\n'
            '    ext_params = {ep:}\n\n'
            '    @staticmethod\n'
            '    def func(p, x):\n'
            '        return {func:}\n'
        )

        f_string = self.func
        for pat, repl in _numexpr_funcs:
            f_string = pat.sub(repl, f_string)

        return template.format(cname=cname, name=self.name, group=self.group, eq=self.eq,
                               p=self.p, ep=self.ext_p, func=f_string)

    @QtCore.pyqtSlot(QtWidgets.QTableWidgetItem)
    @QtCore.pyqtSlot(str)
    def update_function(self, _):
        """Rebuild the ``def func(x, ...):`` label from table column 0."""
        var = [self.tableWidget.item(row, 0).text() for row in range(self.tableWidget.rowCount())]

        if self.use_nuclei.isChecked():
            var.append('nucleus=2.67522128e8')

        self.label.setText('def func(x, ' + ', '.join(var) + '):')

    def change_visibility(self):
        """Keep only the toggled group box checked and show its panel."""
        sender = self.sender()

        for gb in [self.groupBox, self.groupBox_2, self.groupBox_3, self.groupBox_4]:
            # Block signals so (un)checking the boxes does not re-enter this slot.
            gb.blockSignals(True)
            gb.setChecked(sender == gb)
            gb.blockSignals(False)

        self.widget_2.setVisible(sender == self.groupBox)
        self.widget_3.setVisible(sender == self.groupBox_2)
        self.widget.setVisible(sender == self.groupBox_3)
        self.namespace_widget.setVisible(sender == self.groupBox_4)
|
|
|
|
|
|
class ChoiceWidget(QtWidgets.QWidget):
    """Editor tab for one selectable option of a user fit.

    Row 0 holds the internal name and the displayed name, row 1 the value type
    (str/int/float) and an "Add option" button, and row 2 a two-column
    Name/Value table spanning the full width.
    """

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self._init_ui()

    def _init_ui(self):
        grid = QtWidgets.QGridLayout()
        grid.setContentsMargins(3, 3, 3, 3)
        grid.setHorizontalSpacing(6)

        # Widgets are created in the same order as before (this fixes the
        # default tab order); the button gets its parent from setLayout().
        self.label = QtWidgets.QLabel('Name', parent=self)
        self.name_line = QtWidgets.QLineEdit(self)
        self.label_2 = QtWidgets.QLabel('Displayed name', parent=self)
        self.display_line = QtWidgets.QLineEdit(self)
        self.label_3 = QtWidgets.QLabel('Type', parent=self)
        self.types = QtWidgets.QComboBox(self)
        self.types.addItems(['str', 'int', 'float'])
        self.add_button = QtWidgets.QPushButton('Add option')
        self.table = QtWidgets.QTableWidget(self)
        self.table.setColumnCount(2)
        self.table.setHorizontalHeaderLabels(['Name', 'Value'])

        # Explicit cells; the second row is what the grid's "next free cell"
        # placement chose for the position-less addWidget calls.
        cells = (
            (self.label, 0, 0),
            (self.name_line, 0, 1),
            (self.label_2, 0, 2),
            (self.display_line, 0, 3),
            (self.label_3, 1, 0),
            (self.types, 1, 1),
            (self.add_button, 1, 2),
        )
        for widget, row, column in cells:
            grid.addWidget(widget, row, column)
        grid.addWidget(self.table, 2, 0, 1, 4)

        self.setLayout(grid)
|
|
|
|
|
|
def register_class(cname, name, group, p, eq, func):
    """Build a new fit class named *cname* at runtime.

    Args:
        cname: name of the created class.
        name: display name, stored as the ``name`` attribute.
        group: group label, stored as the ``type`` attribute.
        p: parameter names, stored as ``params``.
        eq: equation text, stored as ``equation``.
        func: numexpr expression in ``x`` and ``p_0``, ``p_1``, ...; wrapped by
            func_decorator and stored as ``func``.

    Returns:
        The new class.  ``cls.func(p, x)`` and ``cls().func(p, x)`` both
        evaluate the expression.
    """
    return type(cname, (), {
        'name': name,
        'type': group,
        'params': p,
        'equation': eq,
        # staticmethod: a plain function here would become a bound method on
        # instances, shifting the instance into the ``p`` argument.
        'func': staticmethod(func_decorator(func)),
    })
|
|
|
|
|
|
def func_decorator(f_string):
    """Turn a numexpr expression string into a callable ``f(p, x)``.

    On each call, ``x`` is exposed to the expression as ``x`` and every entry
    of ``p`` as ``p_0``, ``p_1``, ... before ``ne.evaluate`` is run.  Further
    positional arguments are ignored.
    """
    def wrapped_f(*args):
        params, x = args[0], args[1]
        local_names = {'x': x}
        local_names.update({'p_{}'.format(idx): value for idx, value in enumerate(params)})
        return ne.evaluate(f_string, local_dict=local_names)

    return wrapped_f
|