From e19d32b7362e548d37c555d143bad5afc01e179b Mon Sep 17 00:00:00 2001 From: dominik Date: Wed, 30 Mar 2022 17:27:02 +0200 Subject: [PATCH] complex fit and stuff --- nmreval/fit/_meta.py | 43 ++++++++++++++++----- nmreval/fit/minimizer.py | 2 +- nmreval/fit/model.py | 39 ++++++------------- nmreval/fit/result.py | 12 +++--- nmreval/gui_qt/fit/fitwindow.py | 38 +++++++++++-------- nmreval/gui_qt/fit/result.py | 37 +++++++++++++----- nmreval/gui_qt/main/management.py | 13 +++---- nmreval/models/bds.py | 63 ++++++++++++++++++++++--------- 8 files changed, 155 insertions(+), 92 deletions(-) diff --git a/nmreval/fit/_meta.py b/nmreval/fit/_meta.py index ad8ad9e..8c8c0ab 100644 --- a/nmreval/fit/_meta.py +++ b/nmreval/fit/_meta.py @@ -65,7 +65,7 @@ class MultiModel: self.bounds = [] self._kwargs_right = {} self._kwargs_left = {} - self._fun_kwargs = {} + self.fun_kwargs = {} # mapping kwargs to kwargs of underlying functions self._ext_int_kw = {} @@ -92,7 +92,7 @@ class MultiModel: if isinstance(func, MultiModel): strcnt = '' kw_dict.update(func.fun_kwargs) - self._fun_kwargs.update({k: v for k, v in kw_dict.items()}) + self.fun_kwargs.update({k: v for k, v in kw_dict.items()}) self._ext_int_kw.update({k: k for k in kw_dict.keys()}) else: @@ -102,7 +102,7 @@ class MultiModel: for k, v in temp_dic.items(): key_ = f'{k}_{idx}' kw_dict[key_] = v - self._fun_kwargs[key_] = v + self.fun_kwargs[key_] = v self._ext_int_kw[key_] = k strcnt = f'({idx})' @@ -116,13 +116,27 @@ class MultiModel: self.bounds.extend([(None, None)]*len(func.params)) def _left_arguments(self, *args, **kwargs): - kw_left = {k_int: kwargs[k_ext] for k_ext, k_int in self._ext_int_kw.items() if k_ext in self._kwargs_left} + kw_left = {} + for k_ext, k_int in self._ext_int_kw.items(): + if k_ext in self._kwargs_left: + if not k_ext.startswith('complex_mode'): + kw_left[k_int] = kwargs[k_ext] + else: + kw_left['complex_mode'] = kwargs['complex_mode'] + pl = args[:self._param_left] return pl, kw_left def _right_arguments(self, *args, **kwargs): - kw_right = {k_int: kwargs[k_ext] for k_ext, k_int in self._ext_int_kw.items() if k_ext in self._kwargs_right} + kw_right = {} + for k_ext, k_int in self._ext_int_kw.items(): + if k_ext in self._kwargs_right: + if not k_ext.startswith('complex_mode'): + kw_right[k_int] = kwargs[k_ext] + else: + kw_right['complex_mode'] = kwargs['complex_mode'] + pr = args[self._param_left:self._param_len] return pr, kw_right @@ -142,10 +156,6 @@ class MultiModel: def right_func(self, x, *args, **kwargs): return self._right.func(x, *args, **kwargs) - @property - def fun_kwargs(self): - return self._fun_kwargs - def subs(self, x, *args, **kwargs): """ Iterator over all sub-functions (depth-first and left-to-right) """ pl, kw_left = self._left_arguments(*args, **kwargs) @@ -159,3 +169,18 @@ class MultiModel: yield from self._right.subs(x, *pr, **kw_right) else: yield self._right.func(x, *pr, **kw_right) + + def sub_name(self): + if isinstance(self._left, MultiModel): + yield from self._left.sub_name() + elif hasattr(self._left, 'name'): + yield self._left.name + else: + yield self.name + '(lhs)' + + if isinstance(self._right, MultiModel): + yield from self._right.sub_name() + elif hasattr(self._right, 'name'): + yield self._right.name + else: + yield self.name + '(rhs)' diff --git a/nmreval/fit/minimizer.py b/nmreval/fit/minimizer.py index 4d5dc05..f1dfb4d 100644 --- a/nmreval/fit/minimizer.py +++ b/nmreval/fit/minimizer.py @@ -61,7 +61,7 @@ class FitRoutine(object): self.result.pop(idx) except ValueError: - raise IndexError('Data {} not found'.format(data)) + raise IndexError(f'Data {data} not found') def set_model(self, func, *args, idx=None, **kwargs): if isinstance(func, Model): diff --git a/nmreval/fit/model.py b/nmreval/fit/model.py index f815f03..631c715 100644 --- a/nmreval/fit/model.py +++ b/nmreval/fit/model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect from typing import Sized @@ -61,11 +63,11 @@ class Model(object): self._int_func = model.func if hasattr(model, 'subs'): self._int_iter = model.subs + self._iter_name = model.sub_name self.is_multi = True else: self._int_iter = model.func - try: self.lb, self.ub = list(zip(*model.bounds)) except AttributeError: @@ -78,16 +80,6 @@ class Model(object): self.fun_kwargs = {k: v.default for k, v in inspect.signature(model.func).parameters.items() if v.default is not inspect.Parameter.empty} - def set_complex(self, state): - if state not in [None, 'complex', 'real', 'imag']: - raise ValueError('"complex" argument is not None, "complex", "real", "imag"') - - self.is_complex = state - if state in ['real', 'imag']: - self._complex_part = state - else: - self._complex_part = False - def set_global_parameter(self, idx, p, var=None, lb=None, ub=None, default_bounds=False): if idx is None: self.parameter = Parameters() @@ -127,26 +119,19 @@ class Model(object): f = self._int_func(x, *p, *self.fun_args, **kwargs) - if self._complex_part: - if self._complex_part == 'real': - return f.real - else: - return f.imag - return f def sub(self, p, x, **kwargs): + if not kwargs: + kwargs = self.fun_kwargs + if not self.is_multi: return [self.func(p, x, **kwargs)] - else: - if not kwargs: - kwargs = self.fun_kwargs - - if self._complex_part: - if self._complex_part == 'real': - return [f.real for f in self._int_iter(x, *p, *self.fun_args, **kwargs)] - else: - return [f.imag for f in self._int_iter(x, *p, *self.fun_args, **kwargs)] - return list(self._int_iter(x, *p, *self.fun_args, **kwargs)) + + def sub_name(self): + if not self.is_multi: + return [self.name] + else: + return list(self._iter_name()) diff --git a/nmreval/fit/result.py b/nmreval/fit/result.py index 8c0a803..8a5a589 100644 --- a/nmreval/fit/result.py +++ b/nmreval/fit/result.py @@ -62,11 +62,11 @@ class FitResultCreator: part_functions = [] if model.is_multi: - for sub_y in model.sub(p_final, _x, **fun_kwargs): + for sub_name, sub_y in zip(model.sub_name(), model.sub(p_final, _x, **fun_kwargs)): if np.iscomplexobj(sub_y): - part_functions.append(Signal(_x, sub_y)) + part_functions.append(Signal(_x, sub_y, name=sub_name)) else: - part_functions.append(Points(_x, sub_y)) + part_functions.append(Points(_x, sub_y, name=sub_name)) _y = model.func(p_final, _x, **fun_kwargs) resid = model.func(p_final, x_orig, **fun_kwargs) - y_orig @@ -98,7 +98,7 @@ class FitResultCreator: FitResult(_x, _y, x_orig, y_orig, parameters, fun_kwargs, resid, nobs, nvar, model.name, stats, idx=idx, corr=correlation, pcorr=partial_correlation, - islog=islog, iscomplex=model.is_complex), + islog=islog), part_functions, ) @@ -139,7 +139,7 @@ class FitResultCreator: class FitResult(Points): def __init__(self, x, y, x_data, y_data, params, fun_kwargs, resid, nobs, nvar, name, stats, - idx=None, corr=None, pcorr=None, islog=False, iscomplex=None, + idx=None, corr=None, pcorr=None, islog=False, **kwargs): self.parameter, name = self._prepare_names(params, name) @@ -155,7 +155,7 @@ class FitResult(Points): self.correlation = corr self.partial_correlation = pcorr self.islog = islog - self.iscomplex = iscomplex + self.iscomplex = np.iscomplexobj(self.y) self.x_data = x_data self.y_data = y_data self._model_name = name diff --git a/nmreval/gui_qt/fit/fitwindow.py b/nmreval/gui_qt/fit/fitwindow.py index cff9c8e..445686b 100644 --- a/nmreval/gui_qt/fit/fitwindow.py +++ b/nmreval/gui_qt/fit/fitwindow.py @@ -6,6 +6,7 @@ from operator import add from string import ascii_letters from typing import Dict, List, Tuple +import numpy as np from pyqtgraph import mkPen from .fit_forms import FitTableWidget @@ -284,10 +285,11 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): parameter['func'] = func parameter['order'] = order parameter['len'] = param_len - if self._complex[k] is None: - parameter['complex'] = self._complex[k] - else: - parameter['complex'] = ['complex', 'real', 'imag'][self._complex[k]] + parameter['complex'] = self._complex[k] + if self._complex[k] is not None: + for p_k, p_v in parameter['parameter'].items(): + p_v[1].update({'complex_mode': self._complex[k]}) + parameter['parameter'][p_k] = p_v[0], p_v[1] func_dict[k] = parameter @@ -409,23 +411,29 @@ class QFitDialog(QtWidgets.QWidget, Ui_FitDialog): color = model['color'] for p, kwargs in parameters.values(): - y = f.func(x, *p, **kwargs) - if is_complex is None: + if is_complex is not None: + y = f.func(x, *p, complex_mode=is_complex, **kwargs) + if np.iscomplexobj(y): + self.preview_lines.append(PlotItem(x=x, y=y.real, pen=mkPen(width=3))) + self.preview_lines.append(PlotItem(x=x, y=y.imag, pen=mkPen(width=3))) + else: + self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3))) + else: + y = f.func(x, *p, **kwargs) self.preview_lines.append(PlotItem(x=x, y=y, pen=mkPen(width=3))) - if is_complex in [0, 1]: - self.preview_lines.append(PlotItem(x=x, y=y.real, pen=mkPen(width=3))) - if is_complex in [0, 2]: - self.preview_lines.append(PlotItem(x=x, y=y.imag, pen=mkPen(width=3))) if isinstance(f, MultiModel): - for i, s in enumerate(f.subs(x, *p, **kwargs)): + sub_kwargs = kwargs.copy() + if is_complex is not None: + sub_kwargs.update({'complex_mode': is_complex}) + + for i, s in enumerate(f.subs(x, *p, **sub_kwargs)): pen_i = mkPen(QtGui.QColor.fromRgbF(*color[i])) - if is_complex is None: - self.preview_lines.append(PlotItem(x=x, y=s, pen=pen_i)) - if is_complex in [0, 1]: + if np.iscomplexobj(s): self.preview_lines.append(PlotItem(x=x, y=s.real, pen=pen_i)) - if is_complex in [0, 2]: self.preview_lines.append(PlotItem(x=x, y=s.imag, pen=pen_i)) + else: + self.preview_lines.append(PlotItem(x=x, y=s, pen=pen_i)) return self.preview_lines diff --git a/nmreval/gui_qt/fit/result.py b/nmreval/gui_qt/fit/result.py index b4ff9e0..3065742 100644 --- a/nmreval/gui_qt/fit/result.py +++ b/nmreval/gui_qt/fit/result.py @@ -41,17 +41,31 @@ class QFitResult(QtWidgets.QDialog, Ui_Dialog): self._opts = [(False, False) for _ in range(len(self._results))] self.residplot = self.graphicsView.addPlot(row=0, col=0) - self.resid_graph = PlotItem(x=[], y=[], symbol='o', symbolPen=None, symbolBrush=mkBrush(color='r'), pen=None) + self.resid_graph = PlotItem(x=[], y=[], + symbol='o', symbolPen=None, symbolBrush=mkBrush(color=(174, 199, 232)), + pen=None) + self.resid_graph_imag = PlotItem(x=[], y=[], + symbol='s', symbolPen=None, symbolBrush=mkBrush(color=(255, 127, 14)), + pen=None) self.residplot.addItem(self.resid_graph) + self.residplot.addItem(self.resid_graph_imag) self.residplot.setLabel('left', 'Residual') self.fitplot = self.graphicsView.addPlot(row=1, col=0) - self.data_graph = PlotItem(x=[], y=[], symbol='o', symbolPen=None, symbolBrush=mkBrush(color='r'), pen=None) + self.data_graph = PlotItem(x=[], y=[], + symbol='o', symbolPen=None, symbolBrush=mkBrush(color=(174, 199, 232)), + pen=None) + self.data_graph_imag = PlotItem(x=[], y=[], + symbol='s', symbolPen=None, symbolBrush=mkBrush(color=(255, 127, 14)), + pen=None) self.fitplot.addItem(self.data_graph) + self.fitplot.addItem(self.data_graph_imag) self.fitplot.setLabel('left', 'Function') self.fit_graph = PlotItem(x=[], y=[]) + self.fit_graph_imag = PlotItem(x=[], y=[]) self.fitplot.addItem(self.fit_graph) + self.fitplot.addItem(self.fit_graph_imag) self.cmap = RdBuCMap(vmin=-1, vmax=1) @@ -138,15 +152,20 @@ class QFitResult(QtWidgets.QDialog, Ui_Dialog): res = self._results[idx] iscomplex = res.iscomplex - self.resid_graph.setData(x=res.x_data, y=res.residual) - if iscomplex == 'complex': - self.data_graph.setData(x=r_[res.x_data, res.x_data], - y=r_[res.y_data.real, res.y_data.imag]) - self.fit_graph.setData(x=r_[res.x, res.x], - y=r_[res.y.real, res.y.imag]) + if iscomplex: + self.data_graph.setData(x=res.x_data, y=res.y_data.real) + self.data_graph_imag.setData(x=res.x_data, y=res.y_data.imag) + self.fit_graph.setData(x=res.x, y=res.y.real) + self.fit_graph_imag.setData(x=res.x, y=res.y.imag) + self.resid_graph.setData(x=res.x_data, y=res.residual.real) + self.resid_graph_imag.setData(x=res.x_data, y=res.residual.imag) else: + self.resid_graph.setData(x=res.x_data, y=res.residual) + self.resid_graph_imag.setData(x=[], y=[]) self.data_graph.setData(x=res.x_data, y=res.y_data) + self.data_graph_imag.setData(x=[], y=[]) self.fit_graph.setData(x=res.x, y=res.y) + self.fit_graph_imag.setData(x=[], y=[]) self.fitplot.setLogMode(x=res.islog) self.residplot.setLogMode(x=res.islog) @@ -243,7 +262,7 @@ class QFitResult(QtWidgets.QDialog, Ui_Dialog): self.redoFit.emit(self._results) elif button_type == self.buttonBox.Ok: - graph = None + graph = '-1' if self.parameter_checkbox.isChecked(): if self.graph_checkBox.checkState() == QtCore.Qt.Checked: graph = '' diff --git a/nmreval/gui_qt/main/management.py b/nmreval/gui_qt/main/management.py index d10200b..2b42314 100644 --- a/nmreval/gui_qt/main/management.py +++ b/nmreval/gui_qt/main/management.py @@ -390,16 +390,15 @@ class UpperManagement(QtCore.QObject): models[model_id] = m m_complex = model_p['complex'] - m.set_complex(m_complex) for set_id, set_params in model_p['parameter'].items(): data_i = self.data[set_id] if we == 'Deltay': we = data_i.y_err**2 - if m_complex is None or m_complex == 'real': + if m_complex is None or m_complex == 1: _y = data_i.y.real - elif m_complex == 'imag' and np.iscomplexobj(self.data[set_id].y): + elif m_complex == 2 and np.iscomplexobj(self.data[set_id].y): _y = data_i.y.imag else: _y = data_i.y @@ -518,14 +517,15 @@ class UpperManagement(QtCore.QObject): if k in parts and show_fit: for subfunc, col in zip(parts[k], TUColorsC): - sub_f_id = self.add(subfunc, color=col, linestyle=LineStyle.Dashed, symbol=SymbolStyle.No) subfunc.value = data_k.value subfunc.group = data_k.group + sub_f_id = self.add(subfunc, color=col, linestyle=LineStyle.Dashed, symbol=SymbolStyle.No) + f_id_list.append(sub_f_id) self.delete_sets(tobedeleted) - if accepted and param_graph is not None: + if accepted and (param_graph != '-1'): self.make_fit_parameter(accepted, graph_id=param_graph) self.newData.emit(f_id_list, gid) @@ -698,7 +698,7 @@ class UpperManagement(QtCore.QObject): @QtCore.pyqtSlot() def update_color(self): - UpperManagement._colors = cycle(Colors) + UpperManagement._colors = cycle(TUColors) for i in self.active: self.data[i].color = next(UpperManagement._colors) @@ -1103,4 +1103,3 @@ class FitWorker(QtCore.QObject): res = [e.args] success = False self.finished.emit(res, success) - diff --git a/nmreval/models/bds.py b/nmreval/models/bds.py index 0aa2359..1bc25ea 100644 --- a/nmreval/models/bds.py +++ b/nmreval/models/bds.py @@ -1,5 +1,3 @@ -from typing import List, Optional, Tuple - import numpy as np from ..distributions import Debye, ColeCole, ColeDavidson, KWW, HavriliakNegami @@ -16,12 +14,18 @@ class _AbstractBDS: iscomplex = True @classmethod - def func(cls, x, *args, **kwargs): + def func(cls, x, *args, complex_mode: int = 0, **kwargs): # args[0] : Delta epsilon # args[1:] : every other parameter - chi = args[0] * cls.susceptibility(2*np.pi*x, *args[1:]) - - return chi + chi = args[0] * cls.susceptibility(2*np.pi*x, *args[1:], **kwargs) + if complex_mode == 0: + return chi + elif complex_mode == 1: + return chi.real + elif complex_mode == 2: + return chi.imag + else: + raise ValueError(f'{complex_mode!r} is not 0, 1, 2') class DebyeBDS(_AbstractBDS): @@ -70,9 +74,16 @@ class EpsInfty: iscomplex = True @staticmethod - def func(x, eps): - ret_val = np.zeros(x.shape, dtype=complex) - ret_val += eps + def func(x, eps, complex_mode: int = 0): + if complex_mode == 0: + ret_val = np.zeros(x.shape, dtype=complex) + ret_val += eps + elif complex_mode == 1: + ret_val = eps * np.ones(x.shape) + elif complex_mode == 2: + ret_val = np.zeros(x.shape) + else: + raise ValueError(f'{complex_mode!r} is not 0, 1, 2') return ret_val @@ -86,8 +97,17 @@ class PowerLawBDS: iscomplex = True @staticmethod - def func(x, a, n): - return a / (1j*x)**n + def func(x, a, n, complex_mode: int = 0): + if complex_mode == 0: + ret_val = np.exp(1j*n*np.pi/2) * a / x**n + elif complex_mode == 1: + ret_val = np.cos(n*np.pi/2) * a / x**n + elif complex_mode == 2: + ret_val = np.sin(n*np.pi/2) * a / x**n + else: + raise ValueError(f'{complex_mode!r} is not 0, 1, 2') + + return ret_val class DCCondBDS: @@ -99,14 +119,21 @@ class DCCondBDS: iscomplex = True @staticmethod - def func(x, sigma): - ret_val = np.zeros(x.shape, dtype=complex) - ret_val += 1j * sigma / x / epsilon0 + def func(x, sigma, complex_mode: int = 0): + if complex_mode == 0: + ret_val = np.zeros(x.shape, dtype=complex) + ret_val += 1j * sigma / x / epsilon0 + elif complex_mode == 1: + ret_val = np.zeros(x.shape) + elif complex_mode == 2: + ret_val = sigma / x / epsilon0 + else: + raise ValueError(f'{complex_mode!r} is not 0, 1, 2') return ret_val -class HavriliakNegamiDerivative: +class DerivativeHavriliakNegami: name = 'Derivative HN' type = 'Dielectric Spectroscopy' params = [r'\Delta\epsilon', r'\tau', r'\alpha', r'\gamma'] @@ -123,7 +150,7 @@ class HavriliakNegamiDerivative: return eps * np.pi * numer / denom / 2. -class ColeColeDerivative: +class DerivativeColeCole: name = 'Derivative CC' type = 'Dielectric Spectroscopy' params = [r'\Delta\epsilon', r'\tau', r'\alpha'] @@ -140,7 +167,7 @@ class ColeColeDerivative: return eps * np.pi * numer / denom / 2. -class ColeDavidsonDerivative: +class DerivativeColeDavidson: name = 'Derivative CD' type = 'Dielectric Spectroscopy' params = [r'\Delta\epsilon', r'\tau', r'\gamma'] @@ -149,7 +176,7 @@ class ColeDavidsonDerivative: @staticmethod def func(x, eps, tau, g): omtau = 2*np.pi*x * tau - numer = g * omtau * np.sin((1+g)*np.sin(omtau)) + numer = g * omtau * np.sin((1+g)*np.arctan(omtau)) denom = (1 + omtau**2)**((1+g)/2.) return eps * np.pi * numer / denom / 2.