# nmreval/src/gui_qt/fit/function_creation_dialog.py
#
# 450 lines
# 15 KiB
# Python
# Raw Normal View History
#
# 2022-10-20 15:23:15 +00:00
from __future__ import annotations
import inspect
import numbers
import textwrap
from typing import Any
import numpy as np
from gui_qt.Qt import QtCore, QtWidgets, QtGui
from gui_qt._py.fitcreationdialog import Ui_Dialog
from gui_qt.lib.namespace import QNamespaceWidget
__all__ = ['QUserFitCreator']

# Shared validator for the class name, parameter names and option names typed by the
# user: the text must start with an ASCII letter and contain no whitespace.
# Raw string: '\S' is a regex escape, not a (invalid) Python string escape.
validator = QtGui.QRegExpValidator(QtCore.QRegExp(r'[A-Za-z]\S*'))
class QUserFitCreator(QtWidgets.QDialog, Ui_Dialog):
    """Dialog that writes the skeleton of a user-defined fit-function class.

    The upper part of the editor (imports, class attributes and the ``func``
    signature) is regenerated from the helper widgets whenever one of them emits
    ``Changed``. Whatever follows the ``func`` signature line is treated as the
    user's function body and is kept across regenerations.
    """

    classCreated = QtCore.pyqtSignal(object)

    # Start of the generated signature line; editor text after that line is user code.
    _FUNC_HEADER = ' def func('

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self.setupUi(self)

        # Imports required by the generated code, filled by namespace_made().
        # Created before any signal is connected so update_function() can always read it.
        self._imports = set()

        self.description_widget = DescWidget(self)
        self.args_widget = ArgWidget(self)
        self.kwargs_widget = KwargsWidget(self)

        self.namespace_widget = QNamespaceWidget(self)
        self.namespace_widget.make_namespace()
        self.namespace_widget.sendKey.connect(self.namespace_made)

        boxes = [
            (self.description_box, self.description_widget),
            (self.args_box, self.args_widget),
            (self.kwargs_box, self.kwargs_widget),
            (self.namespace_box, self.namespace_widget),
        ]
        for box, widget in boxes:
            box.layout().addWidget(widget)
            # Connect each widget exactly once; a second connection would run
            # update_function twice for every change.
            try:
                widget.Changed.connect(self.update_function)
            except AttributeError:
                # widgets without a ``Changed`` signal do not feed the generated header
                pass
            box.layout().addStretch()

        self.update_function()

    def __call__(self, *args, **kwargs):
        # Calling the dialog returns the dialog itself; all arguments are ignored.
        return self

    def _build_header(self) -> str:
        """Return the generated text up to and including the ``func`` signature line.

        Raises whatever the helper widgets raise for invalid input (e.g. ValueError).
        """
        var = self.args_widget.get_parameter()
        var += self.kwargs_widget.get_parameter()

        parts = []
        # A set has no stable iteration order; sort so the import lines stay put.
        for imps in sorted(self._imports):
            if len(imps) == 2:
                parts.append(f'from {imps[0]} import {imps[1]}\n')
            elif imps[0] == 'numpy':
                parts.append('import numpy as np\n')
        if self._imports:
            parts.append('\n\n')

        parts.append(self.description_widget.get_strings())
        parts.append(self.args_widget.get_strings())
        parts.append(self.kwargs_widget.get_strings())
        parts.append('\n @staticmethod\n')
        parts.append(f"{self._FUNC_HEADER}x, {', '.join(var)}):\n")
        return ''.join(parts)

    def _user_body(self) -> str:
        """Return the editor text following the generated ``func`` signature line ('' if absent)."""
        text = self.plainTextEdit.toPlainText()
        # The generated signature always starts a line, and no earlier generated line
        # starts with it, so the first match is the generated one.
        start = text.find('\n' + self._FUNC_HEADER)
        if start == -1:
            return ''
        end = text.find('\n', start + 1)
        return '' if end == -1 else text[end + 1:]

    def update_function(self):
        """Regenerate the header in the editor from the helper widgets.

        Code written below the ``func`` signature is preserved. If a widget raises
        (invalid names, values or bounds), the error is shown in a warning box and
        the editor is left unchanged.
        """
        try:
            header = self._build_header()
        except Exception as e:
            # str(e) rather than e.args[0]: exceptions raised without arguments have empty args
            QtWidgets.QMessageBox.warning(self, 'Failure', f'Error found: {e}')
            return
        self.plainTextEdit.setPlainText(header + self._user_body())

    def change_visibility(self):
        """Expand the box that sent the signal and collapse all the others."""
        sender = self.sender()
        for box in (self.description_box, self.args_box, self.kwargs_box, self.namespace_box):
            # signals blocked while changing expansion, presumably to avoid re-entering this slot
            box.blockSignals(True)
            box.setExpansion(sender == box)
            box.blockSignals(False)

    def namespace_made(self, invalue: str):
        """Insert code for namespace entry *invalue* at the cursor and refresh the header.

        ``ns[invalue][0]`` is the namespace object and ``ns[invalue][-1]`` its
        signature string. Imports needed by the inserted code are recorded in
        ``self._imports``.
        """
        ns = self.namespace_widget.namespace.namespace
        func_value = ns[invalue][0]

        ret_func = ''
        import_name = None
        if func_value is None:
            ret_func = invalue
        elif isinstance(func_value, numbers.Number):
            # insertPlainText only accepts strings
            ret_func = str(func_value)
        elif isinstance(func_value, np.ufunc):
            self._imports.add(('numpy',))
            ret_func = 'np.' + func_value.__name__ + '(x)'
        else:
            f_string = ns[invalue][-1]
            args = f_string[f_string.find('('):]
            if inspect.ismethod(func_value):
                # Method bound to a class: import the class, call its ``func``.
                import_name = func_value.__self__.__name__
                ret_func = import_name + '.func' + args
            elif hasattr(func_value, '__qualname__'):
                import_name = func_value.__qualname__.split('.')[0]
                ret_func = import_name

            module = inspect.getmodule(func_value)
            # Only record importable names; 'Cls.func(...)' or '' would generate a broken import.
            if import_name and module is not None:
                self._imports.add((module.__name__, import_name))

        self.plainTextEdit.insertPlainText(ret_func)
        self.update_function()
class KwargsWidget(QtWidgets.QWidget):
    """Editor for the keyword arguments ("choices") of the generated fit class.

    Each choice is a ChoiceWidget in its own tab; the optional gyromagnetic-ratio
    checkbox adds a ``nucleus`` keyword argument.
    """

    Changed = QtCore.pyqtSignal()

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        # Running counter used only to give new choices distinct default names;
        # it is not a tab index once tabs have been removed.
        self._num_kwargs = 0
        self._setup_ui()

    def _setup_ui(self):
        layout = QtWidgets.QGridLayout()
        layout.setContentsMargins(3, 3, 3, 3)
        layout.setHorizontalSpacing(3)

        self.use_nuclei = QtWidgets.QCheckBox('Add gyromagnetic ratio', self)
        self.use_nuclei.stateChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.use_nuclei, 0, 0, 1, 3)

        self.choices = QtWidgets.QTabWidget(self)
        layout.addWidget(self.choices, 1, 0, 1, 3)

        self.add_choice_button = QtWidgets.QPushButton('Add choice', self)
        self.add_choice_button.clicked.connect(self.add_choice)
        layout.addWidget(self.add_choice_button, 2, 0, 1, 1)

        self.rem_choice_button = QtWidgets.QPushButton('Remove choice', self)
        self.rem_choice_button.clicked.connect(self.remove_choice)
        layout.addWidget(self.rem_choice_button, 2, 1, 1, 1)

        self.setLayout(layout)

    def add_choice(self):
        """Append a new choice tab, make it current and emit ``Changed``."""
        c = ChoiceWidget(self._num_kwargs, self)
        c.Changed.connect(self.update_choice)
        # addTab returns the real index of the new tab
        idx = self.choices.addTab(c, c.name_line.text())
        self._num_kwargs += 1
        self.choices.setCurrentIndex(idx)
        self.Changed.emit()

    def remove_choice(self):
        """Remove the current choice tab (if any) and emit ``Changed``."""
        idx = self.choices.currentIndex()
        if idx == -1:
            return
        page = self.choices.widget(idx)
        self.choices.removeTab(idx)
        # QTabWidget.removeTab does not delete the page widget itself
        page.deleteLater()
        self.Changed.emit()

    def update_choice(self):
        """Set the tab title of the emitting choice to its name and emit ``Changed``."""
        widget = self.sender()
        # the sender's own tab, which need not be the current one
        idx = self.choices.indexOf(widget)
        if idx != -1:
            self.choices.setTabText(idx, widget.name_line.text())
        self.Changed.emit()

    def get_parameter(self) -> list[str]:
        """Return the ``name=default`` entries for the keyword arguments of ``func``."""
        # NOTE(review): 2.67522128e7 is ten times smaller than the proton gyromagnetic
        # ratio in rad s^-1 T^-1 (~2.675e8); confirm the intended value/units.
        var = ['nucleus=2.67522128e7'] if self.use_nuclei.isChecked() else []
        var += [self.choices.widget(idx).get_parameter() for idx in range(self.choices.count())]
        return var

    def get_strings(self) -> str:
        """Return the ``choices`` class-attribute line, or '' when there are no choices.

        ``choices`` is always written as a list of
        ``(display name, keyword name, options)`` tuples, also for a single entry.
        """
        kwargs = []
        if self.use_nuclei.isChecked():
            # raw literal: the generated text must contain the backslash of r'\gamma'
            kwargs.append(r"(r'\gamma', 'nucleus', gamma)")
        kwargs.extend(self.choices.widget(i).get_strings() for i in range(self.choices.count()))
        if not kwargs:
            return ''
        return f" choices = [{', '.join(kwargs)}]\n"
class ChoiceWidget(QtWidgets.QWidget):
    """Editor for one keyword argument whose value is selected from a table of options.

    Each table row is one option. Column 0 holds the option name, edited in a
    QLineEdit cell widget. Column 1 holds the value text, and column 2 holds a combo
    box with the type that text is converted to.
    """

    Changed = QtCore.pyqtSignal()

    def __init__(self, idx: int, parent=None):
        super().__init__(parent=parent)
        self._setup_ui()
        self.name_line.setText('choice' + str(idx))
        self.add_option()

    def _setup_ui(self):
        layout = QtWidgets.QGridLayout()
        layout.setContentsMargins(3, 3, 3, 3)
        layout.setHorizontalSpacing(3)

        self.name_label = QtWidgets.QLabel('Name', self)
        layout.addWidget(self.name_label, 0, 0, 1, 1)

        self.name_line = QtWidgets.QLineEdit(self)
        # The name becomes a keyword argument of ``func``: validate it like the other names.
        self.name_line.setValidator(validator)
        self.name_line.textChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.name_line, 0, 1, 1, 1)

        self.disp_label = QtWidgets.QLabel('Disp. name', self)
        layout.addWidget(self.disp_label, 1, 0, 1, 1)

        self.display_line = QtWidgets.QLineEdit(self)
        self.display_line.textChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.display_line, 1, 1, 1, 1)

        self.add_button = QtWidgets.QPushButton('Add option', self)
        self.add_button.clicked.connect(self.add_option)
        layout.addWidget(self.add_button, 2, 0, 1, 2)

        self.remove_button = QtWidgets.QPushButton('Remove option', self)
        self.remove_button.clicked.connect(self.remove_option)
        layout.addWidget(self.remove_button, 3, 0, 1, 2)

        self.table = QtWidgets.QTableWidget(self)
        self.table.setColumnCount(3)
        self.table.setHorizontalHeaderLabels(['Name', 'Value', 'Type'])
        self.table.itemChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.table, 0, 2, 4, 1)

        self.setLayout(layout)

    def _free_index(self) -> int:
        """Return the smallest index >= rowCount() whose default name 'opt<index>' is unused."""
        taken = {self.table.cellWidget(row, 0).text() for row in range(self.table.rowCount())}
        idx = self.table.rowCount()
        while f'opt{idx}' in taken:
            idx += 1
        return idx

    def add_option(self):
        """Append an option row (value text 'None', type 'None') and emit ``Changed``."""
        # unique default name, also after rows in the middle were removed
        idx = self._free_index()

        self.table.blockSignals(True)
        row = self.table.rowCount()
        self.table.setRowCount(row + 1)
        self.table.setItem(row, 0, QtWidgets.QTableWidgetItem('opt' + str(idx)))

        lineedit = QtWidgets.QLineEdit()
        lineedit.setValidator(validator)
        lineedit.setFrame(False)
        lineedit.setText('opt' + str(idx))
        lineedit.textChanged.connect(lambda x: self.Changed.emit())
        self.table.setCellWidget(row, 0, lineedit)

        self.table.setItem(row, 1, QtWidgets.QTableWidgetItem('None'))
        self.table.setItem(row, 2, QtWidgets.QTableWidgetItem(''))

        cb = QtWidgets.QComboBox()
        cb.addItems(['None', 'str', 'float', 'int', 'bool'])
        cb.currentIndexChanged.connect(lambda x: self.Changed.emit())
        self.table.setCellWidget(row, 2, cb)
        self.table.blockSignals(False)

        self.Changed.emit()

    def remove_option(self):
        """Remove the current option row; the last remaining row is never removed."""
        if self.table.rowCount() > 1:
            self.table.blockSignals(True)
            self.table.removeRow(self.table.currentRow())
            self.table.blockSignals(False)
            self.Changed.emit()

    def get_parameter(self) -> str:
        """Return ``'name=value'`` for ``func``'s signature, defaulting to the first option.

        Raises:
            ValueError: if the name is not a valid identifier or the value cannot be converted.
        """
        name = self.name_line.text()
        if not name.isidentifier():
            raise ValueError(f'Invalid name for choice: {name!r}')
        return f'{name}={self._make_value(0)!r}'

    def get_strings(self) -> str:
        """Return the ``(display name, keyword name, {option: value, ...})`` tuple as source text.

        An empty display name falls back to the keyword name.
        """
        opts = []
        for i in range(self.table.rowCount()):
            # the editable option name lives in the QLineEdit cell widget, not the item below it
            name = self.table.cellWidget(i, 0).text()
            opts.append(f'{name!r}: {self._make_value(i)!r}')
        opts = f"{{{', '.join(opts)}}}"

        name = self.name_line.text()
        disp = self.display_line.text() or name
        # display name first, keyword name second, like (r'\gamma', 'nucleus', gamma)
        return '(' + ', '.join([repr(disp), repr(name), opts]) + ')'

    @staticmethod
    def _to_bool(text: str) -> bool:
        """Parse *text* as a boolean; plain ``bool('False')`` would be True."""
        lowered = text.strip().lower()
        if lowered in ('true', '1', 'yes', 'on'):
            return True
        if lowered in ('false', '0', 'no', 'off', ''):
            return False
        raise ValueError(f'{text!r} is not a boolean')

    def _make_value(self, i) -> Any:
        """Convert the value text of option row *i* to the type chosen in its combo box.

        Type 'None' always yields ``None``.

        Raises:
            ValueError: if the text cannot be converted to the chosen type.
        """
        dtype = self.table.cellWidget(i, 2).currentIndex()
        val = self.table.item(i, 1).text()
        # same order as the combo box entries 'None', 'str', 'float', 'int', 'bool'
        cast = [None, str, float, int, self._to_bool]
        if dtype == 0:
            return None
        try:
            return cast[dtype](val)
        except (TypeError, ValueError) as err:
            raise ValueError(f'Invalid argument for {self.table.cellWidget(i, 0).text()}') from err
class ArgWidget(QtWidgets.QWidget):
    """Table editor for the fit parameters of the generated class.

    The columns are the variable name (a QLineEdit cell widget), the display name,
    the lower bound and the upper bound. A bound of '--', 'None' or an empty cell
    means "no bound".
    """

    Changed = QtCore.pyqtSignal()

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self._setup_ui()

    def _setup_ui(self):
        layout = QtWidgets.QGridLayout()
        layout.setContentsMargins(3, 3, 3, 3)
        layout.setHorizontalSpacing(3)

        self.table = QtWidgets.QTableWidget(self)
        self.table.setColumnCount(4)
        self.table.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows)
        self.table.setRowCount(0)
        self.table.setHorizontalHeaderLabels(['Variable', 'Disp. name', 'Lower bound', 'Upper bound'])
        self.table.itemChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.table, 0, 0, 1, 3)

        self.add_button = QtWidgets.QPushButton('Add parameter', self)
        self.add_button.clicked.connect(self.add_variable)
        layout.addWidget(self.add_button, 1, 0, 1, 1)

        self.rem_button = QtWidgets.QPushButton('Remove parameter', self)
        self.rem_button.clicked.connect(self.remove_variable)
        layout.addWidget(self.rem_button, 1, 1, 1, 1)

        spacer = QtWidgets.QSpacerItem(0, 0)
        layout.addItem(spacer, 1, 2, 1, 1)

        self.setLayout(layout)

    def _free_index(self) -> int:
        """Return the smallest index >= rowCount() whose default name 'p<index>' is unused."""
        taken = {self.table.cellWidget(row, 0).text() for row in range(self.table.rowCount())}
        idx = self.table.rowCount()
        while f'p{idx}' in taken:
            idx += 1
        return idx

    def add_variable(self):
        """Append a parameter row with default names and no bounds, then emit ``Changed``."""
        # unique default name, also after rows in the middle were removed
        idx = self._free_index()

        self.table.blockSignals(True)
        row = self.table.rowCount()
        self.table.setRowCount(row + 1)
        self.table.setItem(row, 0, QtWidgets.QTableWidgetItem('p' + str(idx)))

        # arguments cannot start with a number or have spaces
        lineedit = QtWidgets.QLineEdit()
        lineedit.setValidator(validator)
        lineedit.setFrame(False)
        lineedit.setText('p' + str(idx))
        lineedit.textChanged.connect(lambda x: self.Changed.emit())
        self.table.setCellWidget(row, 0, lineedit)

        self.table.setItem(row, 1, QtWidgets.QTableWidgetItem('p_{' + str(idx) + '}'))
        self.table.setItem(row, 2, QtWidgets.QTableWidgetItem('--'))
        self.table.setItem(row, 3, QtWidgets.QTableWidgetItem('--'))
        self.table.blockSignals(False)

        self.Changed.emit()

    def remove_variable(self):
        """Remove the current row (no-op without a current row) and emit ``Changed``."""
        self.table.blockSignals(True)
        self.table.removeRow(self.table.currentRow())
        self.table.blockSignals(False)
        self.Changed.emit()

    def get_parameter(self) -> list[str]:
        """Return the parameter names in table order.

        Raises:
            ValueError: if a name is not a valid identifier, is repeated, or is ``x``
                (the first argument of ``func``).
        """
        var = []
        for row in range(self.table.rowCount()):
            name = self.table.cellWidget(row, 0).text()
            if not name.isidentifier():
                raise ValueError(f'Invalid parameter name {name!r}')
            if name == 'x' or name in var:
                raise ValueError(f'Parameter name {name!r} is already used')
            var.append(name)
        return var

    def _read_bound(self, row: int, col: int) -> float | None:
        """Return the bound in cell (*row*, *col*) as float, or None for '', '--' or 'None'."""
        text = self.table.item(row, col).text().strip()
        if text in ('', '--', 'None'):
            return None
        try:
            return float(text)
        except ValueError as err:
            name = self.table.cellWidget(row, 0).text()
            raise ValueError(f'Invalid bound {text!r} for {name}') from err

    def get_strings(self):
        """Return the ``params`` and ``bounds`` class-attribute lines.

        Raises:
            ValueError: if a bound is not a number or a lower bound is not below its upper bound.
        """
        args = []
        bnds = []
        for row in range(self.table.rowCount()):
            args.append(self.table.item(row, 1).text())
            lb = self._read_bound(row, 2)
            ub = self._read_bound(row, 3)
            if lb is not None and ub is not None and not (lb < ub):
                raise ValueError('Some bounds are invalid')
            bnds.append(f'({lb}, {ub})')

        stringi = f' params = {args}\n'
        stringi += f" bounds = [{', '.join(bnds)}]\n"
        return stringi
class DescWidget(QtWidgets.QWidget):
    """Form for the descriptive part of the generated class.

    The form holds the class name, ``name``, ``group`` and the display ``equation``.
    """

    Changed = QtCore.pyqtSignal()

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self._setup_ui()

    def _setup_ui(self):
        layout = QtWidgets.QGridLayout()
        layout.setContentsMargins(3, 3, 3, 3)
        layout.setSpacing(3)

        self.klass_label = QtWidgets.QLabel('Class', self)
        layout.addWidget(self.klass_label, 0, 0, 1, 1)

        self.klass_lineedit = QtWidgets.QLineEdit(self)
        self.klass_lineedit.setValidator(validator)
        self.klass_lineedit.setText('UserClass')
        self.klass_lineedit.textChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.klass_lineedit, 0, 1, 1, 1)

        self.name_label = QtWidgets.QLabel('Name', self)
        layout.addWidget(self.name_label, 1, 0, 1, 1)

        self.name_lineedit = QtWidgets.QLineEdit(self)
        self.name_lineedit.setText('Name of function')
        self.name_lineedit.textChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.name_lineedit, 1, 1, 1, 1)

        self.group_label = QtWidgets.QLabel('Group', self)
        layout.addWidget(self.group_label, 2, 0, 1, 1)

        self.group_lineedit = QtWidgets.QLineEdit(self)
        self.group_lineedit.setText('User-defined')
        self.group_lineedit.textChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.group_lineedit, 2, 1, 1, 1)

        self.eq_label = QtWidgets.QLabel('Disp. equation', self)
        layout.addWidget(self.eq_label, 3, 0, 1, 1)

        self.eq_lineedit = QtWidgets.QLineEdit(self)
        self.eq_lineedit.textChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.eq_lineedit, 3, 1, 1, 1)

        self.setLayout(layout)

    def get_strings(self) -> str:
        """Return the ``class`` line followed by the name/group/equation attribute lines.

        Raises:
            ValueError: if the class name is empty or not a valid Python identifier.
        """
        klass = self.klass_lineedit.text()
        if klass == '':
            raise ValueError('Class name is empty')
        # the line-edit validator still admits characters such as '-' or '.'
        if not klass.isidentifier():
            raise ValueError(f'Class name {klass!r} is not a valid identifier')
        return (
            f'class {klass}:\n'
            f' name = {self.name_lineedit.text()!r}\n'
            f' group = {self.group_lineedit.text()!r}\n'
            f' equation = {self.eq_lineedit.text()!r}\n'
        )
if __name__ == '__main__':
    # Manual check: open the dialog stand-alone and exit with the event loop's return code.
    import sys

    qapp = QtWidgets.QApplication([])
    creator = QUserFitCreator()
    creator.show()
    exit_code = qapp.exec()
    sys.exit(exit_code)