487 lines
16 KiB
Python
487 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
import inspect
|
|
import numbers
|
|
import pathlib
|
|
import re
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
from gui_qt.Qt import QtCore, QtWidgets, QtGui
|
|
from gui_qt._py.fitcreationdialog import Ui_Dialog
|
|
from gui_qt.lib.namespace import QNamespaceWidget
|
|
|
|
# Only the dialog is part of this module's public API.
__all__ = ['QUserFitCreator']

# Shared validator admitting only Python-identifier-like text: a letter or
# underscore followed by letters, digits or underscores. It is attached to the
# line edits whose text ends up as names in the generated code.
_identifier_regexp = QtCore.QRegExp('[_A-Za-z][_A-Za-z0-9]*')
validator = QtGui.QRegExpValidator(_identifier_regexp)
|
|
# Matches the header line of the generated ``func`` (searched line by line).
# Tolerates arbitrary indentation/whitespace and an optional return annotation,
# so that a user-edited header is still recognised and the body below it is not
# discarded when the header is regenerated.
pattern = re.compile(r'^\s*def\s+func\s*\(.*\)\s*(?:->[^:]*)?:', flags=re.MULTILINE)
|
|
|
|
|
|
class QUserFitCreator(QtWidgets.QDialog, Ui_Dialog):
    """Dialog that assembles the source code of a user-defined fit class.

    The generated code lives in ``plainTextEdit``. Imports, class attributes
    and the ``func`` signature are regenerated from the sub-widgets whenever
    they change. Everything below the ``def func(...)`` line is kept as typed.
    On accept, the text is appended to ``filepath`` and ``classCreated`` is
    emitted.
    """

    classCreated = QtCore.pyqtSignal()

    def __init__(self, filepath: str | pathlib.Path, parent=None):
        super().__init__(parent=parent)
        self.setupUi(self)

        self.filepath = pathlib.Path(filepath)
        # Import entries collected from the namespace widget:
        # ('numpy',) or (module_name, imported_name).
        # Created before any signal is connected, since update_function reads it.
        self._imports = set()

        self.description_widget = DescWidget(self)
        self.args_widget = ArgWidget(self)
        self.kwargs_widget = KwargsWidget(self)
        self.namespace_widget = QNamespaceWidget(self)
        self.namespace_widget.make_namespace()
        self.namespace_widget.sendKey.connect(self.namespace_made)

        sections = [
            (self.description_box, self.description_widget),
            (self.args_box, self.args_widget),
            (self.kwargs_box, self.kwargs_widget),
            (self.namespace_box, self.namespace_widget),
        ]
        for box, widget in sections:
            box.layout().addWidget(widget)
            # Each widget's Changed signal is connected exactly once, here.
            try:
                widget.Changed.connect(self.update_function)
            except AttributeError:
                # Widgets without a Changed signal do not drive regeneration.
                pass
            box.layout().addStretch()

        self.update_function()

    def __call__(self, filepath: str | pathlib.Path):
        """Point the dialog at another target file and return the dialog itself."""
        self.filepath = pathlib.Path(filepath)

        return self

    def _current_func_body(self) -> str:
        """Return the editor text after the first ``def func`` header line.

        Line breaks are preserved. Returns '' if no header line is present.
        """
        lines = self.plainTextEdit.toPlainText().split('\n')
        for i, line in enumerate(lines):
            if pattern.search(line) is not None:
                return '\n'.join(lines[i + 1:])
        return ''

    def _import_lines(self) -> str:
        """Render the collected imports in a stable (sorted) order.

        When any import entry exists, two blank lines are appended after the
        imports.
        """
        lines = []
        for imp in sorted(self._imports):
            if len(imp) == 2:
                lines.append(f'from {imp[0]} import {imp[1]}\n')
            elif imp[0] == 'numpy':
                lines.append('import numpy as np\n')

        text = ''.join(lines)
        if self._imports:
            text += '\n\n'
        return text

    def update_function(self):
        """Regenerate the editor text from the sub-widgets, keeping the user's body.

        Errors raised while collecting the widget contents (e.g. invalid bounds
        or option values) are shown in a warning dialog. In that case the
        editor text is left unchanged.
        """
        func_body = self._current_func_body()

        try:
            var = self.args_widget.get_parameter() + self.kwargs_widget.get_parameter()

            text = self._import_lines()
            text += self.description_widget.get_strings()
            text += self.args_widget.get_strings()
            text += self.kwargs_widget.get_strings()

            text += '\n @staticmethod\n'
            if var:
                text += f" def func(x, {', '.join(var)}):\n"
            else:
                text += ' def func(x):\n'

            text += func_body

            self.plainTextEdit.setPlainText(text)
        except Exception as e:
            # UI boundary: report the problem instead of letting the slot raise.
            QtWidgets.QMessageBox.warning(self, 'Failure', f'Error found: {e}')

    def change_visibility(self):
        """Expand the box that sent the signal and collapse all others."""
        sender = self.sender()

        for box in (self.description_box, self.args_box, self.kwargs_box, self.namespace_box):
            # Block signals so that setExpansion does not re-trigger this slot.
            box.blockSignals(True)
            box.setExpansion(sender == box)
            box.blockSignals(False)

    def namespace_made(self, invalue: str):
        """Insert the namespace entry ``invalue`` at the cursor and record its import.

        Behaviour depends on the first element of the entry:

        - None: the key itself is inserted.
        - A number: its ``str()`` form is inserted.
        - A numpy ufunc: ``np.<name>(x)`` is inserted.
        - Anything else: a qualified call is inserted, using the argument list
          taken from the entry's last element.
        """
        ns = self.namespace_widget.namespace.namespace
        func_value = ns[invalue][0]
        ret_func = ''
        import_name = ''

        if func_value is None:
            ret_func = invalue

        elif isinstance(func_value, numbers.Number):
            # insertPlainText only accepts str.
            ret_func = str(func_value)

        elif isinstance(func_value, np.ufunc):
            self._imports.add(('numpy',))
            ret_func = 'np.' + func_value.__name__ + '(x)'

        else:
            f_string = ns[invalue][-1]
            args = f_string[f_string.find('('):]

            if inspect.ismethod(func_value):
                ret_func = func_value.__self__.__name__ + '.func' + args
                import_name = func_value.__self__.__name__

            elif hasattr(func_value, '__qualname__'):
                import_name = func_value.__qualname__.split('.')[0]
                ret_func = func_value.__qualname__ + args

            module = inspect.getmodule(func_value)
            # Only record a usable import. An empty name or an unknown module
            # would generate an invalid "from ... import ..." line.
            if import_name and module is not None:
                self._imports.add((module.__name__, import_name))

        self.plainTextEdit.insertPlainText(ret_func)
        self.update_function()

    def accept(self):
        """Append the generated code to ``filepath`` and close the dialog.

        Code that does not compile is rejected with a warning. This keeps a
        broken class from being written into the target file. The dialog then
        stays open.
        """
        code = self.plainTextEdit.toPlainText()
        try:
            compile(code, str(self.filepath), 'exec')
        except (SyntaxError, ValueError) as err:
            QtWidgets.QMessageBox.warning(self, 'Failure', f'Generated code is invalid: {err}')
            return

        # Python source files are UTF-8 by default; do not depend on the locale.
        with self.filepath.open('a', encoding='utf-8') as f:
            f.write('\n\n')
            f.write(code)

        self.classCreated.emit()
        super().accept()
|
|
|
|
|
|
class KwargsWidget(QtWidgets.QWidget):
    """Editor for the keyword arguments of the generated ``func``.

    It offers an optional gyromagnetic-ratio keyword (``nucleus``) plus any
    number of "choice" keywords. Each choice is edited in its own tab
    (a ``ChoiceWidget``).
    """

    Changed = QtCore.pyqtSignal()

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        # Counter used only to give new choices unique default names. It is
        # never decremented, so it must not be used as a tab index.
        self._num_kwargs = 0

        self._setup_ui()

    def _setup_ui(self):
        layout = QtWidgets.QGridLayout()
        layout.setContentsMargins(3, 3, 3, 3)
        layout.setHorizontalSpacing(3)

        self.use_nuclei = QtWidgets.QCheckBox('Add gyromagnetic ratio', self)
        self.use_nuclei.stateChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.use_nuclei, 0, 0, 1, 3)

        self.choices = QtWidgets.QTabWidget(self)
        layout.addWidget(self.choices, 1, 0, 1, 3)

        self.add_choice_button = QtWidgets.QPushButton('Add choice', self)
        self.add_choice_button.clicked.connect(self.add_choice)
        layout.addWidget(self.add_choice_button, 2, 0, 1, 1)

        self.rem_choice_button = QtWidgets.QPushButton('Remove choice', self)
        self.rem_choice_button.clicked.connect(self.remove_choice)
        layout.addWidget(self.rem_choice_button, 2, 1, 1, 1)

        self.setLayout(layout)

    def add_choice(self):
        """Append a new choice tab, select it and notify listeners."""
        choice = ChoiceWidget(self._num_kwargs, self)
        choice.Changed.connect(self.update_choice)
        # addTab returns the real index; it differs from the counter after removals.
        idx = self.choices.addTab(choice, choice.name_line.text())

        self._num_kwargs += 1

        self.choices.setCurrentIndex(idx)
        self.Changed.emit()

    def remove_choice(self):
        """Remove the currently selected choice tab, if any, and notify listeners."""
        idx = self.choices.currentIndex()
        if idx < 0:
            return

        widget = self.choices.widget(idx)
        self.choices.removeTab(idx)
        # removeTab does not delete the page widget; free it explicitly.
        widget.deleteLater()
        self.Changed.emit()

    def update_choice(self):
        """Sync the tab title with the name of the choice that changed."""
        idx = self.choices.indexOf(self.sender())
        if idx >= 0:
            self.choices.setTabText(idx, self.sender().name_line.text())
        self.Changed.emit()

    def get_parameter(self):
        """Return the keyword-argument strings for the ``func`` signature.

        Each entry has the form ``name=default``.
        """
        if self.use_nuclei.isChecked():
            var = ['nucleus=2.67522128e7']
        else:
            var = []
        var += [self.choices.widget(idx).get_parameter() for idx in range(self.choices.count())]

        return var

    def get_strings(self) -> str:
        """Return the ``choices = [...]`` class-attribute line, or '' if there are none.

        A list is always emitted, so a single choice is still a sequence of
        choices and not a bare 3-tuple.
        """
        kwargs = []
        if self.use_nuclei.isChecked():
            kwargs.append(r"(r'\gamma', 'nucleus', gamma)")

        for i in range(self.choices.count()):
            kwargs.append(self.choices.widget(i).get_strings())

        if kwargs:
            return f" choices = [{', '.join(kwargs)}]\n"
        return ''
|
|
|
|
|
|
class ChoiceWidget(QtWidgets.QWidget):
    """Editor for one keyword argument whose value comes from a fixed set of options.

    The argument name and display name are edited in line edits. Each option
    is a table row with three columns:

    - a name (validated line edit),
    - a value entered as text,
    - a type, chosen in a combo box, used to convert that text.
    """

    Changed = QtCore.pyqtSignal()

    # Combo-box index -> converter; None means the option value is None.
    # Order matches the items added in add_option: None, str, float, int, bool.
    _CASTS = (None, str, float, int, bool)
    # Accepted (case-insensitive) spellings for bool-typed option values.
    _TRUE_STRINGS = frozenset({'true', '1', 'yes', 'on'})
    _FALSE_STRINGS = frozenset({'false', '0', 'no', 'off', 'none', ''})

    def __init__(self, idx: int, parent=None):
        super().__init__(parent=parent)

        self._setup_ui()

        self.name_line.setText('choice' + str(idx))
        self.add_option()

    def _setup_ui(self):
        layout = QtWidgets.QGridLayout()
        layout.setContentsMargins(3, 3, 3, 3)
        layout.setHorizontalSpacing(3)

        self.name_label = QtWidgets.QLabel('Name', self)
        layout.addWidget(self.name_label, 0, 0, 1, 1)
        self.name_line = QtWidgets.QLineEdit(self)
        # The name becomes a keyword argument of func, so it must be an identifier
        # (same validator as the other name fields).
        self.name_line.setValidator(validator)
        self.name_line.textChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.name_line, 0, 1, 1, 1)

        self.disp_label = QtWidgets.QLabel('Disp. name', self)
        layout.addWidget(self.disp_label, 1, 0, 1, 1)
        self.display_line = QtWidgets.QLineEdit(self)
        self.display_line.textChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.display_line, 1, 1, 1, 1)

        self.add_button = QtWidgets.QPushButton('Add option', self)
        self.add_button.clicked.connect(self.add_option)
        layout.addWidget(self.add_button, 2, 0, 1, 2)

        self.remove_button = QtWidgets.QPushButton('Remove option', self)
        self.remove_button.clicked.connect(self.remove_option)
        layout.addWidget(self.remove_button, 3, 0, 1, 2)

        self.table = QtWidgets.QTableWidget(self)
        self.table.setColumnCount(3)
        self.table.setHorizontalHeaderLabels(['Name', 'Value', 'Type'])
        self.table.itemChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.table, 0, 2, 4, 1)

        self.setLayout(layout)

    def add_option(self):
        """Append an option row with default name ``opt<row>``, value 'None' and type None."""
        self.table.blockSignals(True)
        row = self.table.rowCount()
        self.table.setRowCount(row + 1)

        self.table.setItem(row, 0, QtWidgets.QTableWidgetItem('opt' + str(row)))
        lineedit = QtWidgets.QLineEdit()
        lineedit.setValidator(validator)
        lineedit.setFrame(False)
        lineedit.setText('opt' + str(row))
        lineedit.textChanged.connect(lambda x: self.Changed.emit())
        self.table.setCellWidget(row, 0, lineedit)
        self.table.setItem(row, 1, QtWidgets.QTableWidgetItem('None'))

        self.table.setItem(row, 2, QtWidgets.QTableWidgetItem(''))
        cb = QtWidgets.QComboBox()
        cb.addItems(['None', 'str', 'float', 'int', 'bool'])
        cb.currentIndexChanged.connect(lambda x: self.Changed.emit())
        self.table.setCellWidget(row, 2, cb)

        self.table.blockSignals(False)

        self.Changed.emit()

    def remove_option(self):
        """Remove the selected option row; the last remaining row is never removed."""
        if self.table.rowCount() > 1:
            self.table.blockSignals(True)
            self.table.removeRow(self.table.currentRow())
            self.table.blockSignals(False)
            self.Changed.emit()

    def get_parameter(self) -> str:
        """Return ``name=<value of the first option>`` for the func signature."""
        return f'{self.name_line.text()}={self._make_value(0)!r}'

    def get_strings(self) -> str:
        """Return the source for ``(name, display name, {option: value, ...})``.

        The name is used as the display name when the display field is empty.
        """
        entries = []
        for row in range(self.table.rowCount()):
            key = self.table.cellWidget(row, 0).text()
            entries.append(f'{key!r}: {self._make_value(row)!r}')
        options = '{' + ', '.join(entries) + '}'

        name = self.name_line.text()
        display = self.display_line.text() or name

        return '(' + ', '.join([repr(name), repr(display), options]) + ')'

    def _make_value(self, i) -> Any:
        """Convert the value text of row ``i`` according to its selected type.

        Raises:
            ValueError: if the text cannot be converted to the selected type.
        """
        dtype = self.table.cellWidget(i, 2).currentIndex()
        text = self.table.item(i, 1).text()
        cast = self._CASTS[dtype]
        if cast is None:
            return None

        option = self.table.cellWidget(i, 0).text()
        if cast is bool:
            # bool('False') is True, so the text is parsed instead of cast.
            flag = text.strip().lower()
            if flag in self._TRUE_STRINGS:
                return True
            if flag in self._FALSE_STRINGS:
                return False
            raise ValueError(f'Invalid argument for {option}')

        try:
            return cast(text)
        except (TypeError, ValueError) as err:
            raise ValueError(f'Invalid argument for {option}') from err
|
|
|
|
|
|
class ArgWidget(QtWidgets.QWidget):
    """Table editor for the fit parameters of the generated ``func``.

    Each row holds four columns:

    - the variable name (validated line edit),
    - its display name,
    - a lower bound,
    - an upper bound.

    A bound of ``'--'``, ``'None'`` or an empty cell means "no bound".
    """

    Changed = QtCore.pyqtSignal()

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        self._setup_ui()

    def _setup_ui(self):
        layout = QtWidgets.QGridLayout()
        layout.setContentsMargins(3, 3, 3, 3)
        layout.setHorizontalSpacing(3)

        self.table = QtWidgets.QTableWidget(self)
        self.table.setColumnCount(4)
        self.table.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows)
        self.table.setRowCount(0)
        self.table.setHorizontalHeaderLabels(['Variable', 'Disp. name', 'Lower bound', 'Upper bound'])
        self.table.itemChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.table, 0, 0, 1, 3)

        self.add_button = QtWidgets.QPushButton('Add parameter', self)
        self.add_button.clicked.connect(self.add_variable)
        layout.addWidget(self.add_button, 1, 0, 1, 1)

        self.rem_button = QtWidgets.QPushButton('Remove parameter', self)
        self.rem_button.clicked.connect(self.remove_variable)
        layout.addWidget(self.rem_button, 1, 1, 1, 1)

        spacer = QtWidgets.QSpacerItem(0, 0)
        layout.addItem(spacer, 1, 2, 1, 1)

        self.setLayout(layout)

    def add_variable(self):
        """Append a parameter row ``p<row>`` with display name ``p_{<row>}`` and no bounds."""
        self.table.blockSignals(True)
        row = self.table.rowCount()
        self.table.setRowCount(row + 1)

        self.table.setItem(row, 0, QtWidgets.QTableWidgetItem('p' + str(row)))
        # arguments cannot start with a number or have spaces
        lineedit = QtWidgets.QLineEdit()
        lineedit.setValidator(validator)
        lineedit.setFrame(False)
        lineedit.setText('p' + str(row))
        lineedit.textChanged.connect(lambda x: self.Changed.emit())
        self.table.setCellWidget(row, 0, lineedit)

        self.table.setItem(row, 1, QtWidgets.QTableWidgetItem('p_{' + str(row) + '}'))
        self.table.setItem(row, 2, QtWidgets.QTableWidgetItem('--'))
        self.table.setItem(row, 3, QtWidgets.QTableWidgetItem('--'))
        self.table.blockSignals(False)

        self.Changed.emit()

    def remove_variable(self):
        """Remove the currently selected parameter row and notify listeners."""
        self.table.blockSignals(True)
        self.table.removeRow(self.table.currentRow())
        self.table.blockSignals(False)

        self.Changed.emit()

    def get_parameter(self) -> list[str]:
        """Return the parameter variable names in table order."""
        return [self.table.cellWidget(row, 0).text() for row in range(self.table.rowCount())]

    @staticmethod
    def _parse_bound(text: str, variable: str) -> float | None:
        """Convert a bound cell to a float, or None for "no bound".

        Raises:
            ValueError: for text that is not a finite number. Such values would
                be written into the generated code as the undefined names
                ``inf``/``nan``.
        """
        import math

        text = text.strip()
        if text in ('', '--', 'None'):
            return None
        try:
            value = float(text)
        except ValueError as err:
            raise ValueError(f'Invalid bound {text!r} for {variable}') from err
        if not math.isfinite(value):
            raise ValueError(f'Invalid bound {text!r} for {variable}')
        return value

    def get_strings(self) -> str:
        """Return the ``params = [...]`` and ``bounds = [...]`` class-attribute lines.

        Raises:
            ValueError: if a bound is not a number, or a lower bound is not
                smaller than its upper bound.
        """
        display_names = []
        bounds = []
        for row in range(self.table.rowCount()):
            variable = self.table.cellWidget(row, 0).text()
            display_names.append(self.table.item(row, 1).text())

            lb = self._parse_bound(self.table.item(row, 2).text(), variable)
            ub = self._parse_bound(self.table.item(row, 3).text(), variable)

            if lb is not None and ub is not None and not (lb < ub):
                raise ValueError('Some bounds are invalid')

            bounds.append(f'({lb}, {ub})')

        stringi = f' params = {display_names}\n'
        stringi += f" bounds = [{', '.join(bounds)}]\n"

        return stringi
|
|
|
|
|
|
class DescWidget(QtWidgets.QWidget):
    """Editor for the descriptive class attributes of the generated fit class.

    It edits four attributes: class name, ``name``, ``type`` (group) and
    ``equation``.
    """

    Changed = QtCore.pyqtSignal()

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        self._setup_ui()

    def _setup_ui(self):
        layout = QtWidgets.QGridLayout()
        layout.setContentsMargins(3, 3, 3, 3)
        layout.setSpacing(3)

        self.klass_label = QtWidgets.QLabel('Class', self)
        layout.addWidget(self.klass_label, 0, 0, 1, 1)
        self.klass_lineedit = QtWidgets.QLineEdit(self)
        self.klass_lineedit.setValidator(validator)
        self.klass_lineedit.setText('UserClass')
        self.klass_lineedit.textChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.klass_lineedit, 0, 1, 1, 1)

        self.name_label = QtWidgets.QLabel('Name', self)
        layout.addWidget(self.name_label, 1, 0, 1, 1)
        self.name_lineedit = QtWidgets.QLineEdit(self)
        self.name_lineedit.setText('Name of function')
        self.name_lineedit.textChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.name_lineedit, 1, 1, 1, 1)

        self.group_label = QtWidgets.QLabel('Group', self)
        layout.addWidget(self.group_label, 2, 0, 1, 1)
        self.group_lineedit = QtWidgets.QLineEdit(self)
        self.group_lineedit.setText('User-defined')
        self.group_lineedit.textChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.group_lineedit, 2, 1, 1, 1)

        self.eq_label = QtWidgets.QLabel('Disp. equation', self)
        layout.addWidget(self.eq_label, 3, 0, 1, 1)
        self.eq_lineedit = QtWidgets.QLineEdit(self)
        self.eq_lineedit.textChanged.connect(lambda x: self.Changed.emit())
        layout.addWidget(self.eq_lineedit, 3, 1, 1, 1)

        self.setLayout(layout)

    def get_strings(self) -> str:
        """Return the ``class`` header plus its ``name``, ``type`` and ``equation`` lines.

        Raises:
            ValueError: if the class name is empty or is a Python keyword.
        """
        import keyword

        klass = self.klass_lineedit.text()
        if klass == '':
            raise ValueError('Class name is empty')
        if keyword.iskeyword(klass):
            raise ValueError(f'Class name {klass!r} is a Python keyword')

        equation = self.eq_lineedit.text()
        # A raw literal keeps LaTeX readable, but it cannot contain a quote or
        # end in a backslash. Fall back to an escaped literal in those cases.
        if "'" in equation or equation.endswith('\\'):
            equation_literal = repr(equation)
        else:
            equation_literal = f"r'{equation}'"

        stringi = f'class {klass}:\n' \
                  f' name = {self.name_lineedit.text()!r}\n' \
                  f' type = {self.group_lineedit.text()!r}\n' \
                  f' equation = {equation_literal}\n'

        return stringi
|
|
|
|
|
|
if __name__ == '__main__':
    import sys

    app = QtWidgets.QApplication([])

    # QUserFitCreator requires the file that accepted classes are appended to.
    # Take it from the command line, defaulting to usermodels.py in the cwd.
    target = sys.argv[1] if len(sys.argv) > 1 else 'usermodels.py'
    win = QUserFitCreator(target)
    win.show()

    sys.exit(app.exec())
|