BUGFIX: VFT;
change to src layout
This commit is contained in:
449
src/gui_qt/fit/function_creation_dialog.py
Normal file
449
src/gui_qt/fit/function_creation_dialog.py
Normal file
@@ -0,0 +1,449 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import numbers
|
||||
import textwrap
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gui_qt.Qt import QtCore, QtWidgets, QtGui
|
||||
from gui_qt._py.fitcreationdialog import Ui_Dialog
|
||||
from gui_qt.lib.namespace import QNamespaceWidget
|
||||
|
||||
__all__ = ['QUserFitCreator']
|
||||
|
||||
validator = QtGui.QRegExpValidator(QtCore.QRegExp('[A-Za-z]\S*'))
|
||||
|
||||
|
||||
class QUserFitCreator(QtWidgets.QDialog, Ui_Dialog):
|
||||
classCreated = QtCore.pyqtSignal(object)
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent=parent)
|
||||
self.setupUi(self)
|
||||
|
||||
self.description_widget = DescWidget(self)
|
||||
self.args_widget = ArgWidget(self)
|
||||
self.kwargs_widget = KwargsWidget(self)
|
||||
self.kwargs_widget.Changed.connect(self.update_function)
|
||||
self.namespace_widget = QNamespaceWidget(self)
|
||||
self.namespace_widget.make_namespace()
|
||||
self.namespace_widget.sendKey.connect(self.namespace_made)
|
||||
|
||||
for b, w in [(self.description_box, self.description_widget), (self.args_box, self.args_widget),
|
||||
(self.kwargs_box, self.kwargs_widget), (self.namespace_box, self.namespace_widget)]:
|
||||
b.layout().addWidget(w)
|
||||
try:
|
||||
w.Changed.connect(self.update_function)
|
||||
except AttributeError:
|
||||
pass
|
||||
b.layout().addStretch()
|
||||
|
||||
self._imports = set()
|
||||
|
||||
self.update_function()
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def update_function(self):
|
||||
try:
|
||||
var = self.args_widget.get_parameter()
|
||||
var += self.kwargs_widget.get_parameter()
|
||||
|
||||
k = ''
|
||||
for imps in self._imports:
|
||||
if len(imps) == 2:
|
||||
k += f'from {imps[0]} import {imps[1]}\n'
|
||||
elif imps[0] == 'numpy':
|
||||
k += 'import numpy as np\n'
|
||||
|
||||
if len(self._imports):
|
||||
k += '\n\n'
|
||||
|
||||
k += self.description_widget.get_strings()
|
||||
k += self.args_widget.get_strings()
|
||||
k += self.kwargs_widget.get_strings()
|
||||
|
||||
k += '\n @staticmethod\n'
|
||||
k += f" def func(x, {', '.join(var)}):\n"
|
||||
|
||||
self.plainTextEdit.setPlainText(k)
|
||||
except Exception as e:
|
||||
QtWidgets.QMessageBox.warning(self, 'Failure', f'Error found: {e.args[0]}')
|
||||
|
||||
def change_visibility(self):
|
||||
sender = self.sender()
|
||||
|
||||
for box in (self.description_box, self.args_box, self.kwargs_box, self.namespace_box):
|
||||
box.blockSignals(True)
|
||||
box.setExpansion(sender == box)
|
||||
box.blockSignals(False)
|
||||
|
||||
def namespace_made(self, invalue: str):
|
||||
ns = self.namespace_widget.namespace.namespace
|
||||
func_value = ns[invalue][0]
|
||||
ret_func = ''
|
||||
|
||||
if func_value is None:
|
||||
ret_func = invalue
|
||||
|
||||
elif isinstance(func_value, numbers.Number):
|
||||
ret_func = func_value
|
||||
|
||||
elif isinstance(func_value, np.ufunc):
|
||||
self._imports.add(('numpy',))
|
||||
ret_func = 'np.'+func_value.__name__ + '(x)'
|
||||
|
||||
else:
|
||||
f_string = ns[invalue][-1]
|
||||
args = f_string[f_string.find('('):]
|
||||
if inspect.ismethod(func_value):
|
||||
ret_func = func_value.__self__.__name__ + '.func'+args
|
||||
elif hasattr(func_value, '__qualname__'):
|
||||
ret_func = func_value.__qualname__.split('.')[0]
|
||||
self._imports.add((inspect.getmodule(func_value).__name__, ret_func))
|
||||
|
||||
self.plainTextEdit.insertPlainText(ret_func)
|
||||
self.update_function()
|
||||
|
||||
|
||||
class KwargsWidget(QtWidgets.QWidget):
|
||||
Changed = QtCore.pyqtSignal()
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent=parent)
|
||||
self._num_kwargs = 0
|
||||
|
||||
self._setup_ui()
|
||||
|
||||
def _setup_ui(self):
|
||||
layout = QtWidgets.QGridLayout()
|
||||
layout.setContentsMargins(3, 3, 3, 3)
|
||||
layout.setHorizontalSpacing(3)
|
||||
|
||||
self.use_nuclei = QtWidgets.QCheckBox('Add gyromagnetic ratio', self)
|
||||
self.use_nuclei.stateChanged.connect(lambda x: self.Changed.emit())
|
||||
layout.addWidget(self.use_nuclei, 0, 0, 1, 3)
|
||||
|
||||
self.choices = QtWidgets.QTabWidget(self)
|
||||
layout.addWidget(self.choices, 1, 0, 1, 3)
|
||||
|
||||
self.add_choice_button = QtWidgets.QPushButton('Add choice', self)
|
||||
self.add_choice_button.clicked.connect(self.add_choice)
|
||||
layout.addWidget(self.add_choice_button, 2, 0, 1, 1)
|
||||
|
||||
self.rem_choice_button = QtWidgets.QPushButton('Remove choice', self)
|
||||
self.rem_choice_button.clicked.connect(self.remove_choice)
|
||||
layout.addWidget(self.rem_choice_button, 2, 1, 1, 1)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def add_choice(self):
|
||||
cnt = self._num_kwargs
|
||||
c = ChoiceWidget(cnt, self)
|
||||
c.Changed.connect(self.update_choice)
|
||||
self.choices.addTab(c, c.name_line.text())
|
||||
|
||||
self._num_kwargs += 1
|
||||
|
||||
self.choices.setCurrentIndex(cnt)
|
||||
self.Changed.emit()
|
||||
|
||||
def remove_choice(self):
|
||||
cnt = self.choices.currentIndex()
|
||||
self.choices.removeTab(cnt)
|
||||
self.Changed.emit()
|
||||
|
||||
def update_choice(self):
|
||||
idx = self.choices.currentIndex()
|
||||
self.choices.setTabText(idx, self.sender().name_line.text())
|
||||
self.Changed.emit()
|
||||
|
||||
def get_parameter(self):
|
||||
if self.use_nuclei.isChecked():
|
||||
var = ['nucleus=2.67522128e7']
|
||||
else:
|
||||
var = []
|
||||
var += [self.choices.widget(idx).get_parameter() for idx in range(self.choices.count())]
|
||||
|
||||
return var
|
||||
|
||||
def get_strings(self) -> str:
|
||||
kwargs = []
|
||||
if self.use_nuclei.isChecked():
|
||||
kwargs.append("(r'\gamma', 'nucleus', gamma)")
|
||||
|
||||
for i in range(self.choices.count()):
|
||||
kwargs.append(self.choices.widget(i).get_strings())
|
||||
if kwargs:
|
||||
return f" choices = {', '.join(kwargs)}\n"
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
class ChoiceWidget(QtWidgets.QWidget):
|
||||
Changed = QtCore.pyqtSignal()
|
||||
|
||||
def __init__(self, idx: int, parent=None):
|
||||
super().__init__(parent=parent)
|
||||
|
||||
self._setup_ui()
|
||||
|
||||
self.name_line.setText('choice' + str(idx))
|
||||
self.add_option()
|
||||
|
||||
def _setup_ui(self):
|
||||
layout = QtWidgets.QGridLayout()
|
||||
layout.setContentsMargins(3, 3, 3, 3)
|
||||
layout.setHorizontalSpacing(3)
|
||||
|
||||
self.name_label = QtWidgets.QLabel('Name', self)
|
||||
layout.addWidget(self.name_label, 0, 0, 1, 1)
|
||||
self.name_line = QtWidgets.QLineEdit(self)
|
||||
self.name_line.textChanged.connect(lambda x: self.Changed.emit())
|
||||
layout.addWidget(self.name_line, 0, 1, 1, 1)
|
||||
|
||||
self.disp_label = QtWidgets.QLabel('Disp. name', self)
|
||||
layout.addWidget(self.disp_label, 1, 0, 1, 1)
|
||||
self.display_line = QtWidgets.QLineEdit(self)
|
||||
self.display_line.textChanged.connect(lambda x: self.Changed.emit())
|
||||
layout.addWidget(self.display_line, 1, 1, 1, 1)
|
||||
|
||||
self.add_button = QtWidgets.QPushButton('Add option', self)
|
||||
self.add_button.clicked.connect(self.add_option)
|
||||
layout.addWidget(self.add_button, 2, 0, 1, 2)
|
||||
|
||||
self.remove_button = QtWidgets.QPushButton('Remove option', self)
|
||||
self.remove_button.clicked.connect(self.remove_option)
|
||||
layout.addWidget(self.remove_button, 3, 0, 1, 2)
|
||||
|
||||
self.table = QtWidgets.QTableWidget(self)
|
||||
self.table.setColumnCount(3)
|
||||
self.table.setHorizontalHeaderLabels(['Name', 'Value', 'Type'])
|
||||
self.table.itemChanged.connect(lambda x: self.Changed.emit())
|
||||
layout.addWidget(self.table, 0, 2, 4, 1)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def add_option(self):
|
||||
self.table.blockSignals(True)
|
||||
row = self.table.rowCount()
|
||||
self.table.setRowCount(row+1)
|
||||
|
||||
self.table.setItem(row, 0, QtWidgets.QTableWidgetItem('opt' + str(row)))
|
||||
lineedit = QtWidgets.QLineEdit()
|
||||
lineedit.setValidator(validator)
|
||||
lineedit.setFrame(False)
|
||||
lineedit.setText('opt'+str(row))
|
||||
lineedit.textChanged.connect(lambda x: self.Changed.emit())
|
||||
self.table.setCellWidget(row, 0, lineedit)
|
||||
self.table.setItem(row, 1, QtWidgets.QTableWidgetItem('None'))
|
||||
|
||||
self.table.setItem(row, 2, QtWidgets.QTableWidgetItem(''))
|
||||
cb = QtWidgets.QComboBox()
|
||||
cb.addItems(['None', 'str', 'float', 'int', 'bool'])
|
||||
cb.currentIndexChanged.connect(lambda x: self.Changed.emit())
|
||||
self.table.setCellWidget(row, 2, cb)
|
||||
|
||||
self.table.blockSignals(False)
|
||||
|
||||
self.Changed.emit()
|
||||
|
||||
def remove_option(self):
|
||||
if self.table.rowCount() > 1:
|
||||
self.table.blockSignals(True)
|
||||
self.table.removeRow(self.table.currentRow())
|
||||
self.table.blockSignals(False)
|
||||
self.Changed.emit()
|
||||
|
||||
def get_parameter(self) -> str:
|
||||
return f'{self.name_line.text()}={self._make_value(0)!r}'
|
||||
|
||||
def get_strings(self) -> str:
|
||||
opts = []
|
||||
for i in range(self.table.rowCount()):
|
||||
name = self.table.item(i, 0).text()
|
||||
val = self._make_value(i)
|
||||
opts.append(f'{name!r}: {val!r}')
|
||||
|
||||
opts = f"{{{', '.join(opts)}}}"
|
||||
|
||||
disp = self.display_line.text()
|
||||
name = self.name_line.text()
|
||||
if disp == '':
|
||||
ret_val = '(' + ', '.join([repr(name), repr(name), opts]) + ')'
|
||||
else:
|
||||
ret_val = '(' + ', '.join([repr(name), repr(disp), opts]) + ')'
|
||||
|
||||
return ret_val
|
||||
|
||||
def _make_value(self, i) -> Any:
|
||||
dtype = self.table.cellWidget(i, 2).currentIndex()
|
||||
val = self.table.item(i, 1).text()
|
||||
cast = [None, str, float, int, bool]
|
||||
if dtype == 0:
|
||||
val = None
|
||||
else:
|
||||
try:
|
||||
val = cast[dtype](val)
|
||||
except:
|
||||
raise ValueError(f'Invalid argument for {self.table.cellWidget(i, 0).text()}')
|
||||
|
||||
return val
|
||||
|
||||
|
||||
class ArgWidget(QtWidgets.QWidget):
|
||||
Changed = QtCore.pyqtSignal()
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent=parent)
|
||||
|
||||
self._setup_ui()
|
||||
|
||||
def _setup_ui(self):
|
||||
layout = QtWidgets.QGridLayout()
|
||||
layout.setContentsMargins(3, 3, 3, 3)
|
||||
layout.setHorizontalSpacing(3)
|
||||
|
||||
self.table = QtWidgets.QTableWidget(self)
|
||||
self.table.setColumnCount(4)
|
||||
self.table.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows)
|
||||
self.table.setRowCount(0)
|
||||
self.table.setHorizontalHeaderLabels(['Variable', 'Disp. name', 'Lower bound', 'Upper bound'])
|
||||
self.table.itemChanged.connect(lambda x: self.Changed.emit())
|
||||
layout.addWidget(self.table, 0, 0, 1, 3)
|
||||
|
||||
self.add_button = QtWidgets.QPushButton('Add parameter', self)
|
||||
self.add_button.clicked.connect(self.add_variable)
|
||||
layout.addWidget(self.add_button, 1, 0, 1, 1)
|
||||
|
||||
self.rem_button = QtWidgets.QPushButton('Remove parameter', self)
|
||||
self.rem_button.clicked.connect(self.remove_variable)
|
||||
layout.addWidget(self.rem_button, 1, 1, 1, 1)
|
||||
|
||||
spacer = QtWidgets.QSpacerItem(0, 0)
|
||||
layout.addItem(spacer, 1, 2, 1, 1)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def add_variable(self):
|
||||
self.table.blockSignals(True)
|
||||
row = self.table.rowCount()
|
||||
self.table.setRowCount(row + 1)
|
||||
|
||||
self.table.setItem(row, 0, QtWidgets.QTableWidgetItem('p' + str(row)))
|
||||
# arguments cannot start with a number or have spaces
|
||||
lineedit = QtWidgets.QLineEdit()
|
||||
lineedit.setValidator(validator)
|
||||
lineedit.setFrame(False)
|
||||
lineedit.setText('p'+str(row))
|
||||
lineedit.textChanged.connect(lambda x: self.Changed.emit())
|
||||
self.table.setCellWidget(row, 0, lineedit)
|
||||
|
||||
self.table.setItem(row, 1, QtWidgets.QTableWidgetItem('p_{' + str(row) + '}'))
|
||||
self.table.setItem(row, 2, QtWidgets.QTableWidgetItem('--'))
|
||||
self.table.setItem(row, 3, QtWidgets.QTableWidgetItem('--'))
|
||||
self.table.blockSignals(False)
|
||||
|
||||
self.Changed.emit()
|
||||
|
||||
def remove_variable(self):
|
||||
self.table.blockSignals(True)
|
||||
self.table.removeRow(self.table.currentRow())
|
||||
self.table.blockSignals(False)
|
||||
|
||||
self.Changed.emit()
|
||||
|
||||
def get_parameter(self) -> list[str]:
|
||||
var = []
|
||||
for row in range(self.table.rowCount()):
|
||||
var.append(self.table.cellWidget(row, 0).text())
|
||||
|
||||
return var
|
||||
|
||||
def get_strings(self):
|
||||
args = []
|
||||
bnds = []
|
||||
for row in range(self.table.rowCount()):
|
||||
args.append(self.table.item(row, 1).text())
|
||||
lb = self.table.item(row, 2).text()
|
||||
lb = None if lb in ['--', 'None'] else float(lb)
|
||||
|
||||
ub = self.table.item(row, 3).text()
|
||||
ub = None if ub in ['--', 'None'] else float(ub)
|
||||
|
||||
if ub is not None and lb is not None:
|
||||
if not (lb < ub):
|
||||
raise ValueError('Some bounds are invalid')
|
||||
|
||||
bnds.append(f'({lb}, {ub})')
|
||||
|
||||
stringi = f' params = {args}\n'
|
||||
stringi += f" bounds = [{', '.join(bnds)}]\n"
|
||||
|
||||
return stringi
|
||||
|
||||
|
||||
class DescWidget(QtWidgets.QWidget):
|
||||
Changed = QtCore.pyqtSignal()
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent=parent)
|
||||
|
||||
self._setup_ui()
|
||||
|
||||
def _setup_ui(self):
|
||||
layout = QtWidgets.QGridLayout()
|
||||
layout.setContentsMargins(3, 3, 3, 3)
|
||||
layout.setSpacing(3)
|
||||
|
||||
self.klass_label = QtWidgets.QLabel('Class', self)
|
||||
layout.addWidget(self.klass_label, 0, 0, 1, 1)
|
||||
self.klass_lineedit = QtWidgets.QLineEdit(self)
|
||||
self.klass_lineedit.setValidator(validator)
|
||||
self.klass_lineedit.setText('UserClass')
|
||||
self.klass_lineedit.textChanged.connect(lambda x: self.Changed.emit())
|
||||
layout.addWidget(self.klass_lineedit, 0, 1, 1, 1)
|
||||
|
||||
self.name_label = QtWidgets.QLabel('Name', self)
|
||||
layout.addWidget(self.name_label, 1, 0, 1, 1)
|
||||
self.name_lineedit = QtWidgets.QLineEdit(self)
|
||||
self.name_lineedit.setText('Name of function')
|
||||
self.name_lineedit.textChanged.connect(lambda x: self.Changed.emit())
|
||||
layout.addWidget(self.name_lineedit, 1, 1, 1, 1)
|
||||
|
||||
self.group_label = QtWidgets.QLabel('Group', self)
|
||||
layout.addWidget(self.group_label, 2, 0, 1, 1)
|
||||
self.group_lineedit = QtWidgets.QLineEdit(self)
|
||||
self.group_lineedit.setText('User-defined')
|
||||
self.group_lineedit.textChanged.connect(lambda x: self.Changed.emit())
|
||||
layout.addWidget(self.group_lineedit, 2, 1, 1, 1)
|
||||
|
||||
self.eq_label = QtWidgets.QLabel('Disp. equation', self)
|
||||
layout.addWidget(self.eq_label, 3, 0, 1, 1)
|
||||
self.eq_lineedit = QtWidgets.QLineEdit(self)
|
||||
self.eq_lineedit.textChanged.connect(lambda x: self.Changed.emit())
|
||||
layout.addWidget(self.eq_lineedit, 3, 1, 1, 1)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def get_strings(self) -> str:
|
||||
if self.klass_lineedit.text() == '':
|
||||
raise ValueError('Class name is empty')
|
||||
stringi = f'class {self.klass_lineedit.text()}:\n' \
|
||||
f' name = {self.name_lineedit.text()!r}\n' \
|
||||
f' group = {self.group_lineedit.text()!r}\n' \
|
||||
f' equation = {self.eq_lineedit.text()!r}\n'
|
||||
|
||||
return stringi
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
app = QtWidgets.QApplication([])
|
||||
win = QUserFitCreator()
|
||||
win.show()
|
||||
|
||||
sys.exit(app.exec())
|
Reference in New Issue
Block a user