From 575cb5e8f6a7d6627371ce63e3153ee5654b3781 Mon Sep 17 00:00:00 2001 From: Dominik Demuth Date: Sun, 21 Jan 2024 17:01:46 +0000 Subject: [PATCH] 209-fit-tree (#222) closes #209 Co-authored-by: Dominik Demuth Reviewed-on: https://gitea.pkm.physik.tu-darmstadt.de/IPKM/nmreval/pulls/222 --- src/gui_qt/fit/fit_forms.py | 2 +- src/gui_qt/fit/fitwindow.py | 4 +-- src/nmreval/fit/_meta.py | 55 +++++++++++++++++++++++++++---------- 3 files changed, 44 insertions(+), 17 deletions(-) diff --git a/src/gui_qt/fit/fit_forms.py b/src/gui_qt/fit/fit_forms.py index 60c3191..2a9c4a2 100644 --- a/src/gui_qt/fit/fit_forms.py +++ b/src/gui_qt/fit/fit_forms.py @@ -70,7 +70,7 @@ class FitModelTree(QtWidgets.QTreeWidget): self.remove_function(item) elif evt.key() == QtCore.Qt.Key.Key_Space: - for item in self.treeWidget.selectedItems(): + for item in self.selectedItems(): cs = item.checkState(0) if cs == QtCore.Qt.CheckState.Unchecked: item.setCheckState(0, QtCore.Qt.CheckState.Checked) diff --git a/src/gui_qt/fit/fitwindow.py b/src/gui_qt/fit/fitwindow.py index e71e2be..a5a4fd0 100644 --- a/src/gui_qt/fit/fitwindow.py +++ b/src/gui_qt/fit/fitwindow.py @@ -275,7 +275,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): func_dict = {} for model_name, model_parameter in self.models.items(): - func, order, param_len = ModelFactory.create_from_list(model_parameter) + func, order, param_len, _ = ModelFactory.create_from_list(model_parameter) multiple_funcs = isinstance(func, MultiModel) if func is None: continue @@ -387,7 +387,7 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): func_dict = {} for k, mod in self.models.items(): - func, order, param_len = ModelFactory.create_from_list(mod) + func, order, param_len, _ = ModelFactory.create_from_list(mod) multiple_funcs = isinstance(func, MultiModel) if k in data: diff --git a/src/nmreval/fit/_meta.py b/src/nmreval/fit/_meta.py index 889cc57..ea33d06 100644 --- a/src/nmreval/fit/_meta.py +++ b/src/nmreval/fit/_meta.py @@ -9,7 +9,13 @@ from inspect import signature, Parameter class ModelFactory: @staticmethod - def create_from_list(funcs: list, left=None, func_order=None, param_len=None, left_cnt=None): + def create_from_list( + funcs: list, + left=None, + func_order: list[int] = None, + param_len: list[int] = None, + left_cnt: int = 0, + ): if func_order is None: func_order = [] @@ -20,32 +26,50 @@ class ModelFactory: if not func['active']: continue - func_order.append(func['cnt']) - param_len.append(len(func['func'].params)) - if func['children']: - right, _, _ = ModelFactory.create_from_list(func['children'], left_cnt=func['pos'], - func_order=func_order, param_len=param_len) - right_cnt = None - right = MultiModel(func['func'], right, func['children'][0]['op'], left_idx=func['cnt'], right_idx=None) + f = func.copy() + f['children'] = [] + right, _, _, right_cnt = ModelFactory.create_from_list( + [f] + func['children'], + left_cnt=func['pos'], + func_order=func_order, + param_len=param_len, + ) else: right = func['func'] right_cnt = func['cnt'] + func_order.append(func['cnt']) + param_len.append(len(func['func'].params)) + if left is None: left = right left_cnt = right_cnt else: - left = MultiModel(left, right, func['op'], - left_idx=left_cnt, right_idx=right_cnt) + left = MultiModel(left, right, func['op'], left_idx=left_cnt, right_idx=right_cnt) - return left, func_order, param_len + return left, func_order, param_len, left_cnt class MultiModel: - op_repr = {operator.add: ' + ', operator.mul: ' * ', operator.sub: ' - ', operator.truediv: ' / '} - str_op = {'+': operator.add, '*': operator.mul, '-': operator.sub, '/': operator.truediv} - int_op = {0: operator.add, 1: operator.mul, 2: operator.sub, 3: operator.truediv} + op_repr = { + operator.add: ' + ', + operator.mul: ' * ', + operator.sub: ' - ', + operator.truediv: ' / ', + } + str_op = { + '+': operator.add, + '*': operator.mul, + '-': operator.sub, + '/': operator.truediv, + } + int_op = { + 0: operator.add, + 1: operator.mul, + 2: operator.sub, + 3: operator.truediv, + } def __init__(self, left: Any, @@ -69,6 +93,9 @@ class MultiModel: if self._op is None: raise ValueError('Invalid binary operator.') + if right_idx is None: + right_idx = left_idx + 1 + self.name = '(' self.params = [] self.bounds = []