from __future__ import annotations

from functools import reduce
from itertools import count, cycle
from operator import add
from string import ascii_letters

import numpy as np
from pyqtgraph import mkPen

from nmreval.fit._meta import MultiModel, ModelFactory
from nmreval.fit.model import Model
from nmreval.fit.parameter import Parameters
from nmreval.fit.result import FitResult

from .fit_forms import FitTableWidget
from .fit_parameter import QFitParameterWidget
from ..lib import Relations
from ..lib.pg_objects import PlotItem
from ..Qt import QtGui, QtCore, QtWidgets
from .._py.fitdialog import Ui_FitDialog


class QFitDialog(QtWidgets.QWidget, Ui_FitDialog):
    """
    Dialog to build fit models from functions, edit their parameters and start fits or previews.

    Each model gets a one-letter id drawn from ``model_cnt``. The function descriptions of each
    model are stored in ``self.models``, and every function has its own parameter widget in
    ``self.param_widgets``, keyed by the function's ``cnt`` id.
    """

    # Class-level counters, shared by all instances of the dialog.
    func_cnt = count()
    model_cnt = cycle(ascii_letters)  # model ids: 'a'..'z', 'A'..'Z', then starting over

    # Sent along with a preview request (presumably the number of x values to evaluate).
    preview_num = 201

    # (models, preview_num, True) requests a preview; ({}, -1, False) removes it.
    preview_emit = QtCore.pyqtSignal(dict, int, bool)
    # (models, linked parameters, fit arguments)
    fitStartSig = QtCore.pyqtSignal(dict, list, dict)
    # Emitted when the abort button is clicked.
    abortFit = QtCore.pyqtSignal()

    def __init__(self, mgmt=None, parent=None):
        super().__init__(parent=parent)
        self.setupUi(self)

        self.parameters = {}
        self.preview_lines = []
        self._current_function = None
        self.param_widgets = {}
        self._management = mgmt

        # the first (pre-existing) combobox entries belong to the first model
        self._current_model = next(QFitDialog.model_cnt)
        self.show_combobox.setItemData(0, self._current_model, QtCore.Qt.ItemDataRole.UserRole)
        self.default_combobox.setItemData(0, self._current_model, QtCore.Qt.ItemDataRole.UserRole)

        self.data_table = FitTableWidget(self.data_widget)
        self.data_widget.addWidget(self.data_table)
        self.data_widget.setText('Data')

        self.models = {}      # model id -> list of function descriptions
        self._func_list = {}  # model id -> parameter list of its functions
        self._complex = {}    # model id -> complex state of the function widget

        self.connected_figure = None

        self.model_frame.hide()
        self.preview_button.hide()

        self.abort_button.clicked.connect(lambda: self.abortFit.emit())

        self.functionwidget.newFunction.connect(self.add_function)
        self.functionwidget.showFunction.connect(self.show_function_parameter)
        self.functionwidget.itemRemoved.connect(self.remove_function)

        self.read_and_load_functions = self.functionwidget.read_and_load_functions

    @QtCore.pyqtSlot(int, int)
    def add_function(self, function_idx: int, function_id: int):
        """Show the parameter widget of a newly added function and allow creating new models."""
        self.show_function_parameter(function_id, function_idx)
        self.newmodel_button.setEnabled(True)

    @QtCore.pyqtSlot(int)
    def remove_function(self, idx: int):
        """
        Remove the parameter widget of the function with id *idx* and schedule its deletion.

        Disables the model buttons when no parameter widgets are left.
        """
        widget = self.param_widgets.pop(idx, None)
        if widget is not None:
            self.stackedWidget.removeWidget(widget)
            widget.deleteLater()
        self._current_function = None

        if len(self.param_widgets) == 0:
            # empty model
            self.newmodel_button.setEnabled(False)
            self.deletemodel_button.setEnabled(False)

    @QtCore.pyqtSlot(int)
    def show_function_parameter(self, function_id: int, function_idx: int | None = None):
        """
        Display the parameter widget of the function *function_id*.

        The widget is created on first use, which requires *function_idx* (index into
        ``functionwidget.functions``).

        Raises:
            ValueError: if no widget exists yet and *function_idx* is None.
        """
        if function_id in self.param_widgets:
            dialog = self.param_widgets[function_id]
        else:
            # create new widget for function
            if function_idx is None:
                raise ValueError('No function index given')

            function = self.functionwidget.functions[function_idx]
            if function is None:
                return

            dialog = QFitParameterWidget(self.stackedWidget)
            data_names = self.data_table.data_list(include_name=True)

            dialog.set_function(function, function_idx)
            dialog.load(data_names)
            dialog.value_requested.connect(self.look_value)

            self.stackedWidget.addWidget(dialog)
            self.param_widgets[function_id] = dialog

        self.stackedWidget.setCurrentWidget(dialog)

        # collect parameter names etc. to allow linkage
        self._func_list[self._current_model] = self.functionwidget.get_parameter_list()
        # dialog.set_links(self._func_list)  # link setup is currently commented out

        # show same tab (general parameter/Data parameter) as the previously shown function;
        # its widget may already have been removed, then fall back to the first tab
        tab_idx = 0
        previous = self.param_widgets.get(self._current_function)
        if previous is not None:
            tab_idx = previous.tabWidget.currentIndex()
        dialog.tabWidget.setCurrentIndex(tab_idx)

        self._current_function = function_id

    def look_value(self, idx: int):
        """
        Copy the ``value`` of every data set listed in the current function widget into
        parameter *idx* of that data set, then refresh the widget.
        """
        func_widget = self.param_widgets.get(self._current_function)
        if func_widget is None:
            return

        set_ids = [func_widget.comboBox.itemData(i) for i in range(func_widget.comboBox.count())]
        for set_id in set_ids:
            func_widget.data_values[set_id][idx] = self._management[set_id].value
        func_widget.change_data(func_widget.comboBox.currentIndex())

    def get_functions(self):
        """Store functions, complex state and parameter list of the current model."""
        self.models[self._current_model] = self.functionwidget.get_functions()
        self._complex[self._current_model] = self.functionwidget.get_complex_state()
        self._func_list[self._current_model] = self.functionwidget.get_parameter_list()

    def load(self, ids: list[str]):
        """
        Add name and id of dataset to list.

        Data sets in 'fit' mode or with an ``isFitPartOf`` relation are unchecked.
        """
        self.data_table.load(ids)

        # deselect all fit sets
        for i in range(self.data_table.rowCount()):
            data_id = self.data_table.item(i, 0).data(QtCore.Qt.ItemDataRole.UserRole+1)
            data_set = self._management[data_id]
            if data_set.mode == 'fit' or data_set.has_relation(Relations.isFitPartOf):
                self.data_table.item(i, 0).setCheckState(QtCore.Qt.CheckState.Unchecked)

        if self.models:
            for m in self.models.keys():
                self.data_table.add_model(m)
        else:
            self.data_table.add_model(self._current_model)

        for dialog in self.param_widgets.values():
            dialog.load(ids)

    @QtCore.pyqtSlot(name='on_newmodel_button_clicked')
    def make_new_model(self):
        """
        Save model with all its functions in dictionary and adjust gui.

        A new model id is drawn and added to the data table and both model comboboxes; selecting
        it in ``show_combobox`` switches to the new (empty) model.
        """
        self.deletemodel_button.setEnabled(True)
        self.model_frame.show()

        idx = next(QFitDialog.model_cnt)
        self.data_table.add_model(idx)

        self.default_combobox.addItem('Model '+idx, userData=idx)
        self.show_combobox.addItem('Model '+idx, userData=idx)
        self.show_combobox.setItemData(self.show_combobox.count()-1, idx, QtCore.Qt.ItemDataRole.UserRole)
        self.show_combobox.setCurrentIndex(self.show_combobox.count()-1)

        self._current_model = idx
        self.stackedWidget.setCurrentIndex(0)

    @QtCore.pyqtSlot(int, name='on_show_combobox_currentIndexChanged')
    def change_model(self, idx: int):
        """
        Save old model and display new model.
        """
        self.get_functions()
        self.functionwidget.clear()

        self._current_model = self.show_combobox.itemData(idx, QtCore.Qt.ItemDataRole.UserRole)
        if self._current_model in self.models and len(self.models[self._current_model]):
            for el in self.models[self._current_model]:
                self.functionwidget.add_function(**el)
            self.functionwidget.set_complex_state(self._complex[self._current_model])
        else:
            self.stackedWidget.setCurrentIndex(0)

    @QtCore.pyqtSlot(name='on_deletemodel_button_clicked')
    def remove_model(self):
        """Delete the current model together with the parameter widgets of all its functions."""
        model_id = self._current_model

        # Removing the shown entry presumably switches show_combobox to another model, so that
        # change_model runs and first stores this model's functions via get_functions.
        self.show_combobox.removeItem(self.show_combobox.findData(model_id))
        self.default_combobox.removeItem(self.default_combobox.findData(model_id))

        self._remove_param_widgets(self.models.get(model_id, []))

        self._complex.pop(model_id, None)
        self._func_list.pop(model_id, None)
        self.models.pop(model_id, None)
        self.data_table.remove_model(model_id)

        if len(self.models) == 1:
            self.model_frame.hide()

    def _remove_param_widgets(self, functions: list):
        """Remove and schedule deletion of the parameter widgets of *functions* and their children."""
        for func in functions:
            func_id = func['cnt']
            widget = self.param_widgets.pop(func_id, None)
            if widget is not None:
                self.stackedWidget.removeWidget(widget)
                widget.deleteLater()

            if self._current_function == func_id:
                # do not keep a reference to a deleted widget
                self._current_function = None

            if func['children']:
                self._remove_param_widgets(func['children'])

    def _prepare(self, model: list, function_use: list | None = None,
                 parameter: dict | None = None, add_idx: bool = False, cnt: int = 0) -> tuple[dict, int]:
        """
        Collect the parameters of all active functions of *model*, recursing into children.

        Args:
            model: list of function descriptions (keys used: 'active', 'cnt', 'color', 'children').
            function_use: forwarded to each parameter widget's ``get_parameter``; presumably the
                ids of the data sets fitted with this model.
            parameter: accumulator dict; a new one is created if None.
            add_idx: if True, keyword arguments are renamed to ``<name>_<function cnt>``.
            cnt: counter value to continue from.

        Returns:
            ``(parameter, cnt)``. If a widget raises ValueError, a warning is shown and
            ``({}, -1)`` is returned.
        """
        if parameter is None:
            parameter = {
                'data_parameter': {},
                'global_parameter': [],
                'links': [],
                'color': [],
            }

        for f in model:
            if not f['active']:
                continue

            try:
                p, glob = self.param_widgets[f['cnt']].get_parameter(function_use)
            except ValueError as e:
                _ = QtWidgets.QMessageBox().warning(self, 'Invalid value', str(e), QtWidgets.QMessageBox.Ok)
                return {}, -1

            parameter['color'].append(f['color'])
            parameter['global_parameter'].extend(glob)

            cnt = f['cnt']
            for p_k, v_k in p.items():
                if add_idx:
                    kw_k = {f'{k}_{cnt}': v for k, v in v_k[1].items()}
                else:
                    kw_k = v_k[1]

                if p_k in parameter['data_parameter']:
                    # append to the parameters already collected for this data set
                    params, kw = parameter['data_parameter'][p_k]
                    params += v_k[0]
                    kw.update(kw_k)
                else:
                    parameter['data_parameter'][p_k] = (v_k[0], kw_k)

            if add_idx:
                cnt += 1

            if f['children']:
                # recurse for children; they belong to the same model and thus the same data sets
                child_parameter, cnt = self._prepare(f['children'], function_use=function_use,
                                                     parameter=parameter, add_idx=add_idx, cnt=cnt)
                if not child_parameter:
                    # a child had invalid input (warning already shown): propagate the error
                    return {}, -1

        return parameter, cnt

    @QtCore.pyqtSlot(name='on_fit_button_clicked')
    def start_fit(self):
        """
        Collect all models with their parameters and emit ``fitStartSig``.

        Returns the dict of prepared models (empty if no model is used). Returns None if the
        input is invalid; a warning has been shown in that case.
        """
        self.get_functions()

        data = self.data_table.collect_data(default=self.default_combobox.currentData())

        func_dict = {}
        for model_name, model_parameter in self.models.items():
            func, order, param_len, _ = ModelFactory.create_from_list(model_parameter)
            if func is None:
                continue
            multiple_funcs = isinstance(func, MultiModel)

            func = Model(func)

            if model_name not in data:
                continue

            parameter, _ = self._prepare(model_parameter, function_use=data[model_name], add_idx=multiple_funcs)
            if not isinstance(parameter, dict) or 'data_parameter' not in parameter \
                    or 'global_parameter' not in parameter:
                # invalid input; _prepare has already shown a warning
                return

            for data_parameter, _ in parameter['data_parameter'].values():
                for pname, param in zip(func.params, data_parameter):
                    param.name = pname

            complex_mode = self._complex[model_name]
            if complex_mode is not None:
                for _, kwargs in parameter['data_parameter'].values():
                    kwargs['complex_mode'] = complex_mode

            for pname, param_value in zip(func.params, parameter['global_parameter']):
                if param_value is not None:
                    param_value.name = pname
                    func.set_global_parameter(param_value)

            parameter['func'] = func
            parameter['order'] = order
            parameter['len'] = param_len
            parameter['complex'] = complex_mode

            func_dict[model_name] = parameter

        # NOTE(review): _prepare never appends to 'links', so this loop currently never runs.
        # It also reads the keys 'glob' and 'parameter', which nothing in this class sets.
        replaceable = []
        for model_name, v in func_dict.items():
            for i, link_i in enumerate(v['links']):
                if link_i is None:
                    continue

                rep_model, rep_func, rep_pos = link_i
                try:
                    f = func_dict[rep_model]
                except KeyError:
                    QtWidgets.QMessageBox().warning(self, 'Invalid value',
                                                    'Parameter cannot be linked: Model is unused',
                                                    QtWidgets.QMessageBox.Ok)
                    return

                try:
                    f_idx = f['order'].index(rep_func)
                except ValueError:
                    QtWidgets.QMessageBox().warning(self, 'Invalid value',
                                                    'Parameter cannot be linked: '
                                                    'Function is probably not checked or deleted',
                                                    QtWidgets.QMessageBox.Ok)
                    return

                repl_idx = sum(f['len'][:f_idx])+rep_pos
                if repl_idx not in f['glob']['idx']:
                    _ = QtWidgets.QMessageBox().warning(self, 'Invalid value',
                                                        'Parameter cannot be linked: '
                                                        'Destination is not a global parameter.',
                                                        QtWidgets.QMessageBox.Ok)
                    return

                replaceable.append((model_name, i, rep_model, repl_idx))

                # take the value from the first data set of the destination model
                replace_value = None
                for p_k in f['parameter'].values():
                    replace_value = p_k[0][repl_idx]
                    break

                if replace_value is not None:
                    for p_k in v['parameter'].values():
                        p_k[0][i] = replace_value

        weight = ['None', 'y', 'y2', 'Deltay'][self.weight_combobox.currentIndex()]

        fit_args = {'we': weight}

        if func_dict:
            self.fitStartSig.emit(func_dict, replaceable, fit_args)

        return func_dict

    @QtCore.pyqtSlot(int, name='on_preview_checkbox_stateChanged')
    def show_preview(self, state: int):
        """Request a preview when the checkbox is checked; otherwise remove the preview."""
        if state:
            self.preview_button.show()
            self.preview_checkbox.setText('')
            self._prepare_preview()
        else:
            self.preview_emit.emit({}, -1, False)
            self.preview_lines = []
            self.preview_button.hide()
            self.preview_checkbox.setText('Preview')

    @QtCore.pyqtSlot(name='on_preview_button_clicked')
    def _prepare_preview(self):
        """
        Collect the models used by checked data sets and emit ``preview_emit``.

        Nothing is emitted if a parameter is invalid; a warning has been shown in that case.
        """
        self.get_functions()

        default_model = self.default_combobox.currentData()
        data = self.data_table.collect_data(default=default_model)

        func_dict = {}
        for k, mod in self.models.items():
            func, order, param_len, _ = ModelFactory.create_from_list(mod)
            if func is None:
                # nothing to evaluate (start_fit skips these models as well)
                continue
            multiple_funcs = isinstance(func, MultiModel)

            if k in data:
                parameter, _ = self._prepare(mod, function_use=data[k], add_idx=multiple_funcs)
                if not parameter:
                    # invalid input; _prepare has already shown a warning
                    return

                parameter['func'] = func
                parameter['order'] = order
                parameter['len'] = param_len

                func_dict[k] = parameter

        # NOTE(review): 'links' is never filled by _prepare, so this loop currently never runs.
        # It also reads the key 'parameter', which nothing in this class sets.
        for v in func_dict.values():
            for i, link_i in enumerate(v['links']):
                if link_i is None:
                    continue

                rep_model, rep_func, rep_pos = link_i
                f = func_dict[rep_model]
                f_idx = f['order'].index(rep_func)
                repl_idx = sum(f['len'][:f_idx]) + rep_pos

                replace_value = None
                for p_k in f['parameter'].values():
                    replace_value = p_k[0][repl_idx]
                    break

                if replace_value is not None:
                    for p_k in v['parameter'].values():
                        p_k[0][i] = replace_value

        self.preview_emit.emit(func_dict, QFitDialog.preview_num, True)

    def make_previews(self, x, models_parameters: dict):
        """
        Evaluate every model of *models_parameters* at *x* and return the resulting plot items.

        Each data set gets one line drawn with a width-3 pen. Complex results produce separate
        lines for the real and imaginary parts. For a MultiModel, every sub-function gets an
        additional line in its function colour. *models_parameters* is presumably the dict
        emitted by ``_prepare_preview``.
        """
        self.preview_lines = []

        # needed to create namespace
        param_dict = Parameters()
        try:
            cnt = 0
            for model in models_parameters.values():
                f = model['func']
                for parameter_list in model['data_parameter'].values():
                    for i, p_value in enumerate(parameter_list[0]):
                        p_value.name = f.params[i]
                        param_dict.add_parameter(f'a{cnt}', p_value)
                        cnt += 1

            for k, model in models_parameters.items():
                f = model['func']
                is_complex = self._complex[k]

                parameters = model['data_parameter']
                color = model['color']

                for p, kwargs in parameters.values():
                    p_value = [pp.value for pp in p]
                    if is_complex is not None:
                        y = f.func(x, *p_value, complex_mode=is_complex, **kwargs)
                        if np.iscomplexobj(y):
                            self.preview_lines.append(PlotItem(x=x, y=y.real, pen=mkPen(width=3)))
                            self.preview_lines.append(PlotItem(x=x, y=y.imag, pen=mkPen(width=3)))
                        else:
                            self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3)))
                    else:
                        y = f.func(x, *p_value, **kwargs)
                        self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3)))

                    if isinstance(f, MultiModel):
                        sub_kwargs = kwargs.copy()
                        if is_complex is not None:
                            sub_kwargs.update({'complex_mode': is_complex})

                        for i, s in enumerate(f.subs(x, *p_value, **sub_kwargs)):
                            pen_i = mkPen(QtGui.QColor.fromRgbF(*color[i]))
                            if np.iscomplexobj(s):
                                self.preview_lines.append(PlotItem(x=x, y=s.real, pen=pen_i))
                                self.preview_lines.append(PlotItem(x=x, y=s.imag, pen=pen_i))
                            else:
                                self.preview_lines.append(PlotItem(x=x, y=s, pen=pen_i))
        finally:
            # always clear the temporary namespace, also if an evaluation raised
            param_dict.clear()

        return self.preview_lines

    def set_parameter(self, parameter: dict[str, FitResult]):
        """
        Write fit results back into the parameter widgets.

        Each fitted data set gets its own values. The data-independent (``fit_id=None``)
        parameters of a model are set to the mean over all of that model's fit results.
        """
        # which data uses which model
        data = self.data_table.collect_data(default=self.default_combobox.currentData())

        for fitted_model, fitted_data in data.items():
            functions = self.models[fitted_model]
            glob_fit_parameter = []

            for fit_id, fit_curve in parameter.items():
                if fit_id in fitted_data:
                    # average plain values, not parameter objects
                    fit_values = [p.value for p in fit_curve.parameter.values()]
                    glob_fit_parameter.append(fit_values)

                    self.set_parameter_iter(fit_id, fit_values, functions)

            if not glob_fit_parameter:
                # no fit result for this model: nothing to average
                continue

            mean_parameter = [reduce(add, p, 0)/len(p) for p in zip(*glob_fit_parameter)]
            self.set_parameter_iter(None, mean_parameter, functions)

    def set_parameter_iter(self, fit_id: str | None, param: list[float], functions: list, cnt: int = 0) -> int:
        """
        Distribute the flat value list *param* over the widgets of the active *functions*.

        Children of active functions are handled recursively; inactive functions and their
        children are skipped, in the same order as ``_prepare`` collects them. Each widget's
        ``set_parameter`` returns how many values it consumed.

        Returns:
            The absolute position in *param* after the last consumed value.
        """
        for model_p in functions:
            if not model_p['active']:
                continue

            cnt += self.param_widgets[model_p['cnt']].set_parameter(fit_id, param[cnt:])
            if model_p['children']:
                # the recursion returns the absolute position, not the number of consumed values
                cnt = self.set_parameter_iter(fit_id, param, model_p['children'], cnt=cnt)

        return cnt

    def closeEvent(self, evt: QtGui.QCloseEvent):
        """Remove any preview before closing."""
        self.preview_emit.emit({}, -1, False)
        self.preview_lines = []
        super().closeEvent(evt)