from __future__ import annotations import inspect import numbers import pathlib import re from typing import Any import numpy as np from gui_qt.Qt import QtCore, QtWidgets, QtGui from gui_qt._py.fitcreationdialog import Ui_Dialog from gui_qt.lib.namespace import QNamespaceWidget __all__ = ['QUserFitCreator'] validator = QtGui.QRegExpValidator(QtCore.QRegExp('[A-Za-z]\S*')) pattern = re.compile(r'def func\(.*\):', flags=re.MULTILINE) class QUserFitCreator(QtWidgets.QDialog, Ui_Dialog): classCreated = QtCore.pyqtSignal() def __init__(self, filepath: str|pathlib.Path, parent=None): super().__init__(parent=parent) self.setupUi(self) self.filepath = pathlib.Path(filepath) self.description_widget = DescWidget(self) self.args_widget = ArgWidget(self) self.kwargs_widget = KwargsWidget(self) self.kwargs_widget.Changed.connect(self.update_function) self.namespace_widget = QNamespaceWidget(self) self.namespace_widget.make_namespace() self.namespace_widget.sendKey.connect(self.namespace_made) for b, w in [(self.description_box, self.description_widget), (self.args_box, self.args_widget), (self.kwargs_box, self.kwargs_widget), (self.namespace_box, self.namespace_widget)]: b.layout().addWidget(w) try: w.Changed.connect(self.update_function) except AttributeError: pass b.layout().addStretch() self._imports = set() self.update_function() def __call__(self, filepath: str|pathlib.Path): self.filepath = pathlib.Path(filepath) return self def update_function(self): prev_text = self.plainTextEdit.toPlainText().split('\n') func_body = '' in_body = False for line in prev_text: if in_body: func_body += line continue if pattern.search(line) is not None: in_body = True try: var = self.args_widget.get_parameter() var += self.kwargs_widget.get_parameter() k = '' for imps in self._imports: if len(imps) == 2: k += f'from {imps[0]} import {imps[1]}\n' elif imps[0] == 'numpy': k += 'import numpy as np\n' if len(self._imports): k += '\n\n' k += self.description_widget.get_strings() k += self.args_widget.get_strings() k += self.kwargs_widget.get_strings() k += '\n @staticmethod\n' if var: k += f" def func(x, {', '.join(var)}):\n" else: k += f' def func(x):\n' k += func_body self.plainTextEdit.setPlainText(k) except Exception as e: QtWidgets.QMessageBox.warning(self, 'Failure', f'Error found: {e.args[0]}') def change_visibility(self): sender = self.sender() for box in (self.description_box, self.args_box, self.kwargs_box, self.namespace_box): box.blockSignals(True) box.setExpansion(sender == box) box.blockSignals(False) def namespace_made(self, invalue: str): ns = self.namespace_widget.namespace.namespace func_value = ns[invalue][0] ret_func = '' import_name = '' if func_value is None: ret_func = invalue elif isinstance(func_value, numbers.Number): ret_func = func_value elif isinstance(func_value, np.ufunc): self._imports.add(('numpy',)) ret_func = 'np.'+func_value.__name__ + '(x)' else: f_string = ns[invalue][-1] args = f_string[f_string.find('('):] if inspect.ismethod(func_value): ret_func = func_value.__self__.__name__ + '.func'+args import_name = func_value.__self__.__name__ elif hasattr(func_value, '__qualname__'): import_name = func_value.__qualname__.split('.')[0] ret_func = func_value.__qualname__ + args self._imports.add((inspect.getmodule(func_value).__name__, import_name)) self.plainTextEdit.insertPlainText(ret_func) self.update_function() def accept(self): # maybe add a check for correctness with self.filepath.open('a') as f: f.write('\n\n') f.write(self.plainTextEdit.toPlainText()) self.classCreated.emit() super().accept() class KwargsWidget(QtWidgets.QWidget): Changed = QtCore.pyqtSignal() def __init__(self, parent=None): super().__init__(parent=parent) self._num_kwargs = 0 self._setup_ui() def _setup_ui(self): layout = QtWidgets.QGridLayout() layout.setContentsMargins(3, 3, 3, 3) layout.setHorizontalSpacing(3) self.use_nuclei = QtWidgets.QCheckBox('Add gyromagnetic ratio', self) self.use_nuclei.stateChanged.connect(lambda x: self.Changed.emit()) layout.addWidget(self.use_nuclei, 0, 0, 1, 3) self.choices = QtWidgets.QTabWidget(self) layout.addWidget(self.choices, 1, 0, 1, 3) self.add_choice_button = QtWidgets.QPushButton('Add choice', self) self.add_choice_button.clicked.connect(self.add_choice) layout.addWidget(self.add_choice_button, 2, 0, 1, 1) self.rem_choice_button = QtWidgets.QPushButton('Remove choice', self) self.rem_choice_button.clicked.connect(self.remove_choice) layout.addWidget(self.rem_choice_button, 2, 1, 1, 1) self.setLayout(layout) def add_choice(self): cnt = self._num_kwargs c = ChoiceWidget(cnt, self) c.Changed.connect(self.update_choice) self.choices.addTab(c, c.name_line.text()) self._num_kwargs += 1 self.choices.setCurrentIndex(cnt) self.Changed.emit() def remove_choice(self): cnt = self.choices.currentIndex() self.choices.removeTab(cnt) self.Changed.emit() def update_choice(self): idx = self.choices.currentIndex() self.choices.setTabText(idx, self.sender().name_line.text()) self.Changed.emit() def get_parameter(self): if self.use_nuclei.isChecked(): var = ['nucleus=2.67522128e7'] else: var = [] var += [self.choices.widget(idx).get_parameter() for idx in range(self.choices.count())] return var def get_strings(self) -> str: kwargs = [] if self.use_nuclei.isChecked(): kwargs.append("(r'\gamma', 'nucleus', gamma)") for i in range(self.choices.count()): kwargs.append(self.choices.widget(i).get_strings()) if kwargs: return f" choices = {', '.join(kwargs)}\n" else: return '' class ChoiceWidget(QtWidgets.QWidget): Changed = QtCore.pyqtSignal() def __init__(self, idx: int, parent=None): super().__init__(parent=parent) self._setup_ui() self.name_line.setText('choice' + str(idx)) self.add_option() def _setup_ui(self): layout = QtWidgets.QGridLayout() layout.setContentsMargins(3, 3, 3, 3) layout.setHorizontalSpacing(3) self.name_label = QtWidgets.QLabel('Name', self) layout.addWidget(self.name_label, 0, 0, 1, 1) self.name_line = QtWidgets.QLineEdit(self) self.name_line.textChanged.connect(lambda x: self.Changed.emit()) layout.addWidget(self.name_line, 0, 1, 1, 1) self.disp_label = QtWidgets.QLabel('Disp. name', self) layout.addWidget(self.disp_label, 1, 0, 1, 1) self.display_line = QtWidgets.QLineEdit(self) self.display_line.textChanged.connect(lambda x: self.Changed.emit()) layout.addWidget(self.display_line, 1, 1, 1, 1) self.add_button = QtWidgets.QPushButton('Add option', self) self.add_button.clicked.connect(self.add_option) layout.addWidget(self.add_button, 2, 0, 1, 2) self.remove_button = QtWidgets.QPushButton('Remove option', self) self.remove_button.clicked.connect(self.remove_option) layout.addWidget(self.remove_button, 3, 0, 1, 2) self.table = QtWidgets.QTableWidget(self) self.table.setColumnCount(3) self.table.setHorizontalHeaderLabels(['Name', 'Value', 'Type']) self.table.itemChanged.connect(lambda x: self.Changed.emit()) layout.addWidget(self.table, 0, 2, 4, 1) self.setLayout(layout) def add_option(self): self.table.blockSignals(True) row = self.table.rowCount() self.table.setRowCount(row+1) self.table.setItem(row, 0, QtWidgets.QTableWidgetItem('opt' + str(row))) lineedit = QtWidgets.QLineEdit() lineedit.setValidator(validator) lineedit.setFrame(False) lineedit.setText('opt'+str(row)) lineedit.textChanged.connect(lambda x: self.Changed.emit()) self.table.setCellWidget(row, 0, lineedit) self.table.setItem(row, 1, QtWidgets.QTableWidgetItem('None')) self.table.setItem(row, 2, QtWidgets.QTableWidgetItem('')) cb = QtWidgets.QComboBox() cb.addItems(['None', 'str', 'float', 'int', 'bool']) cb.currentIndexChanged.connect(lambda x: self.Changed.emit()) self.table.setCellWidget(row, 2, cb) self.table.blockSignals(False) self.Changed.emit() def remove_option(self): if self.table.rowCount() > 1: self.table.blockSignals(True) self.table.removeRow(self.table.currentRow()) self.table.blockSignals(False) self.Changed.emit() def get_parameter(self) -> str: return f'{self.name_line.text()}={self._make_value(0)!r}' def get_strings(self) -> str: opts = [] for i in range(self.table.rowCount()): name = self.table.item(i, 0).text() val = self._make_value(i) opts.append(f'{name!r}: {val!r}') opts = f"{{{', '.join(opts)}}}" disp = self.display_line.text() name = self.name_line.text() if disp == '': ret_val = '(' + ', '.join([repr(name), repr(name), opts]) + ')' else: ret_val = '(' + ', '.join([repr(name), repr(disp), opts]) + ')' return ret_val def _make_value(self, i) -> Any: dtype = self.table.cellWidget(i, 2).currentIndex() val = self.table.item(i, 1).text() cast = [None, str, float, int, bool] if dtype == 0: val = None else: try: val = cast[dtype](val) except: raise ValueError(f'Invalid argument for {self.table.cellWidget(i, 0).text()}') return val class ArgWidget(QtWidgets.QWidget): Changed = QtCore.pyqtSignal() def __init__(self, parent=None): super().__init__(parent=parent) self._setup_ui() def _setup_ui(self): layout = QtWidgets.QGridLayout() layout.setContentsMargins(3, 3, 3, 3) layout.setHorizontalSpacing(3) self.table = QtWidgets.QTableWidget(self) self.table.setColumnCount(4) self.table.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows) self.table.setRowCount(0) self.table.setHorizontalHeaderLabels(['Variable', 'Disp. name', 'Lower bound', 'Upper bound']) self.table.itemChanged.connect(lambda x: self.Changed.emit()) layout.addWidget(self.table, 0, 0, 1, 3) self.add_button = QtWidgets.QPushButton('Add parameter', self) self.add_button.clicked.connect(self.add_variable) layout.addWidget(self.add_button, 1, 0, 1, 1) self.rem_button = QtWidgets.QPushButton('Remove parameter', self) self.rem_button.clicked.connect(self.remove_variable) layout.addWidget(self.rem_button, 1, 1, 1, 1) spacer = QtWidgets.QSpacerItem(0, 0) layout.addItem(spacer, 1, 2, 1, 1) self.setLayout(layout) def add_variable(self): self.table.blockSignals(True) row = self.table.rowCount() self.table.setRowCount(row + 1) self.table.setItem(row, 0, QtWidgets.QTableWidgetItem('p' + str(row))) # arguments cannot start with a number or have spaces lineedit = QtWidgets.QLineEdit() lineedit.setValidator(validator) lineedit.setFrame(False) lineedit.setText('p'+str(row)) lineedit.textChanged.connect(lambda x: self.Changed.emit()) self.table.setCellWidget(row, 0, lineedit) self.table.setItem(row, 1, QtWidgets.QTableWidgetItem('p_{' + str(row) + '}')) self.table.setItem(row, 2, QtWidgets.QTableWidgetItem('--')) self.table.setItem(row, 3, QtWidgets.QTableWidgetItem('--')) self.table.blockSignals(False) self.Changed.emit() def remove_variable(self): self.table.blockSignals(True) self.table.removeRow(self.table.currentRow()) self.table.blockSignals(False) self.Changed.emit() def get_parameter(self) -> list[str]: var = [] for row in range(self.table.rowCount()): var.append(self.table.cellWidget(row, 0).text()) return var def get_strings(self) -> str: args = [] bnds = [] for row in range(self.table.rowCount()): args.append(self.table.item(row, 1).text()) lb = self.table.item(row, 2).text() lb = None if lb in ['--', 'None'] else float(lb) ub = self.table.item(row, 3).text() ub = None if ub in ['--', 'None'] else float(ub) if ub is not None and lb is not None: if not (lb < ub): raise ValueError('Some bounds are invalid') bnds.append(f'({lb}, {ub})') stringi = f' params = {args}\n' stringi += f" bounds = [{', '.join(bnds)}]\n" return stringi class DescWidget(QtWidgets.QWidget): Changed = QtCore.pyqtSignal() def __init__(self, parent=None): super().__init__(parent=parent) self._setup_ui() def _setup_ui(self): layout = QtWidgets.QGridLayout() layout.setContentsMargins(3, 3, 3, 3) layout.setSpacing(3) self.klass_label = QtWidgets.QLabel('Class', self) layout.addWidget(self.klass_label, 0, 0, 1, 1) self.klass_lineedit = QtWidgets.QLineEdit(self) self.klass_lineedit.setValidator(validator) self.klass_lineedit.setText('UserClass') self.klass_lineedit.textChanged.connect(lambda x: self.Changed.emit()) layout.addWidget(self.klass_lineedit, 0, 1, 1, 1) self.name_label = QtWidgets.QLabel('Name', self) layout.addWidget(self.name_label, 1, 0, 1, 1) self.name_lineedit = QtWidgets.QLineEdit(self) self.name_lineedit.setText('Name of function') self.name_lineedit.textChanged.connect(lambda x: self.Changed.emit()) layout.addWidget(self.name_lineedit, 1, 1, 1, 1) self.group_label = QtWidgets.QLabel('Group', self) layout.addWidget(self.group_label, 2, 0, 1, 1) self.group_lineedit = QtWidgets.QLineEdit(self) self.group_lineedit.setText('User-defined') self.group_lineedit.textChanged.connect(lambda x: self.Changed.emit()) layout.addWidget(self.group_lineedit, 2, 1, 1, 1) self.eq_label = QtWidgets.QLabel('Disp. equation', self) layout.addWidget(self.eq_label, 3, 0, 1, 1) self.eq_lineedit = QtWidgets.QLineEdit(self) self.eq_lineedit.textChanged.connect(lambda x: self.Changed.emit()) layout.addWidget(self.eq_lineedit, 3, 1, 1, 1) self.setLayout(layout) def get_strings(self) -> str: if self.klass_lineedit.text() == '': raise ValueError('Class name is empty') stringi = f'class {self.klass_lineedit.text()}:\n' \ f' name = {self.name_lineedit.text()!r}\n' \ f' type = {self.group_lineedit.text()!r}\n' \ f' equation = {self.eq_lineedit.text()!r}\n' return stringi if __name__ == '__main__': import sys app = QtWidgets.QApplication([]) win = QUserFitCreator() win.show() sys.exit(app.exec())