from __future__ import annotations

from collections import OrderedDict
from itertools import cycle
from typing import Any

import numpy as np
from pyqtgraph import mkPen

from nmreval.data.points import Points
from nmreval.data.signals import Signal
from nmreval.utils.text import convert
from nmreval.data.bds import BDS
from nmreval.data.dsc import DSC
from nmreval.lib.colors import BaseColor, TUColors
from nmreval.lib.lines import LineStyle
from nmreval.lib.symbols import SymbolStyle, symbolcycle
from nmreval.data.nmr import Spectrum, FID

from ..Qt import QtCore, QtGui
from ..io.exporters import GraceExporter
from ..lib.decorators import plot_update
from ..lib.pg_objects import ErrorBars, PlotItem


class ExperimentContainer(QtCore.QObject):
    """Base container that wraps a data object together with its plot items.

    Holds the wrapped data (``_data``), the pyqtgraph items used to draw it
    (``plot_real``, ``plot_imag``, ``plot_error``), the ids of associated fits
    and a table of named actions that can be applied via :meth:`apply`.
    Subclasses create the plot items in ``_init_plot``.
    """

    dataChanged = QtCore.pyqtSignal(str)
    labelChanged = QtCore.pyqtSignal(str, str)
    groupChanged = QtCore.pyqtSignal(str, str)

    # Class-level iterator shared by every container, so each newly created
    # data set without an explicit color takes the next color of the cycle.
    colors = cycle(TUColors)

    def __init__(self, identifier, data, **kwargs):
        super().__init__()

        self.id = str(identifier)

        self._fits = []
        self._data = data
        # Object indexed by id (see get_fits / get_namespace) and providing
        # ``graphs`` (see save); presumably the application's data manager.
        self._manager = kwargs.get('manager')

        self.graph = ''
        self.mode = 'point'

        self.plot_real = None
        self.plot_imag = None
        self.plot_error = None

        self.actions = {}
        self._update_actions()

    @plot_update
    def _init_plot(self):
        raise NotImplementedError

    def __getitem__(self, item):
        try:
            return self._data[item]
        except KeyError:
            raise KeyError('Unknown key %s' % str(item)) from None

    def __del__(self):
        del self._data
        del self.plot_real
        del self.plot_imag
        del self.plot_error

    def __repr__(self):
        return 'name:' + self.name

    def __len__(self):
        return len(self._data)

    def copy(self, full: bool = False, keep_color: bool = True):
        """Return a copy.

        Args:
            full: If True, build a new container of the same type around a copy
                of the data; otherwise return only a copy of the data object.
            keep_color: With ``full``, transfer symbol/line styles of the real
                (and, if present, imaginary) plot to the new container.
        """
        if not full:
            return self._data.copy()

        pen_dict = {}
        if keep_color:
            pen_dict = {
                'symbol': self.plot_real.symbol,
                'symbolcolor': self.plot_real.symbolcolor,
                'symbolsize': self.plot_real.symbolsize,
                'linestyle': self.plot_real.linestyle,
                'linecolor': self.plot_real.linecolor,
                'linewidth': self.plot_real.linewidth,
            }

        new_data = type(self)(str(self.id), self._data.copy(), manager=self._manager, **pen_dict)
        new_data.mode = self.mode
        if keep_color and self.plot_imag is not None:
            new_data.plot_imag.set_symbol(symbol=self.plot_imag.symbol, size=self.plot_imag.symbolsize,
                                          color=self.plot_imag.symbolcolor)
            new_data.plot_imag.set_line(style=self.plot_imag.linestyle, width=self.plot_imag.linewidth,
                                        color=self.plot_imag.linecolor)

        return new_data

    def change_type(self, data):
        """Wrap *data* in a new container whose class matches the data type.

        FID/Spectrum/BDS data yields a SignalContainer, other Points data a
        PointContainer. The new container keeps id and manager but gets fresh
        default plot styles.

        Raises:
            TypeError: if *data* is none of the supported types.
        """
        # FID/Spectrum/BDS are checked first; presumably they derive from Points.
        if isinstance(data, (FID, Spectrum, BDS)):
            new_type = SignalContainer
        elif isinstance(data, Points):
            new_type = PointContainer
        else:
            raise TypeError('Unknown data type')

        return new_type(str(self.id), data, manager=self._manager)

    @property
    def x(self):
        return self._data.x[self._data.mask]

    @x.setter
    @plot_update
    def x(self, value):
        if len(self._data.x) == len(value):
            self._data.x = value
        elif len(self._data.x[self._data.mask]) == len(value):
            # New x only covers the unmasked points: drop the masked points from
            # y and y_err too, so all arrays keep the same length, then unmask.
            self._data.x = value
            self._data.y = self._data.y[self._data.mask]
            self._data.y_err = self._data.y_err[self._data.mask]
            self._data.mask = np.ma.array(np.ones_like(self._data.x, dtype=bool))
        else:
            raise ValueError('x and y have different dimensions!')

    @property
    def y(self):
        return self._data.y[self._data.mask]

    @y.setter
    @plot_update
    def y(self, value):
        if len(self._data.y) == len(value):
            self._data.y = value
        elif len(self._data.y[self._data.mask]) == len(value):
            # New y only covers the unmasked points: drop the masked points from
            # x and y_err too, so all arrays keep the same length, then unmask.
            self._data.y = value
            self._data.x = self._data.x[self._data.mask]
            self._data.y_err = self._data.y_err[self._data.mask]
            self._data.mask = np.ma.array(np.ones_like(self._data.y, dtype=bool))
        else:
            raise ValueError('x and y have different dimensions!')

    @property
    def y_err(self):
        return self._data.y_err[self._data.mask]

    @y_err.setter
    @plot_update
    def y_err(self, value):
        if len(self._data.y_err) == len(value):
            self._data.y_err = value
        elif len(self._data.y[self._data.mask]) == len(value):
            # Only the unmasked entries are replaced.
            self._data.y_err[self._data.mask] = value
        else:
            raise ValueError('y_err has not correct length')

    @property
    def name(self):
        return self._data.name

    @name.setter
    @plot_update
    def name(self, value: str):
        self._data.name = value
        self.plot_real.opts['name'] = value
        try:
            self.plot_imag.opts['name'] = value
        except AttributeError:
            # No imaginary plot (plot_imag is None).
            pass

        # A name that parses as a number is also used as the data value.
        try:
            num_val = float(value)
            self._data.value = num_val
        except ValueError:
            pass

    @property
    def value(self):
        return self._data.value

    @value.setter
    def value(self, val):
        self._data.value = float(val)

    @property
    def group(self):
        return str(self._data['group'])

    @group.setter
    def group(self, valium):
        self._data['group'] = str(valium)
        self.groupChanged.emit(self.id, str(valium))

    @property
    def data(self):
        return self._data

    @data.setter
    @plot_update
    def data(self, new_data):
        self._data = new_data
        # Actions are bound methods of the data object, so rebind them.
        self._update_actions()

    @property
    def opts(self):
        return self._data.meta

    @property
    def plots(self):
        return self.plot_real, self.plot_imag, self.plot_error

    def get_state(self) -> dict:
        """Return a dict with id, data state, mode, fit ids and plot styles."""
        ret_dic = {
            'id': self.id,
            'data': self._data.get_state(),
            'mode': self.mode,
            'fits': self._fits,
            'real': ({'symbol': self.plot_real.symbol.value,
                      'size': self.plot_real.symbolsize,
                      'color': self.plot_real.symbolcolor.value},
                     {'style': self.plot_real.linestyle.value,
                      'width': self.plot_real.linewidth,
                      'color': self.plot_real.linecolor.value})
        }

        if self.plot_imag is not None:
            ret_dic['imag'] = ({'symbol': self.plot_imag.symbol.value,
                                'size': self.plot_imag.symbolsize,
                                'color': self.plot_imag.symbolcolor.value},
                               {'style': self.plot_imag.linestyle.value,
                                'width': self.plot_imag.linewidth,
                                'color': self.plot_imag.linecolor.value})

        return ret_dic

    def get_fits(self):
        return [self._manager[idx] for idx in self._fits]

    def has_fits(self):
        return len(self._fits) != 0

    def set_fits(self, value: str | list, replace: bool = False):
        """Register fit ids.

        Args:
            value: A single fit id or a list/tuple of ids.
            replace: If True, the given ids replace the current ones; otherwise
                they are appended.

        Raises:
            TypeError: if ``replace`` is set and *value* is not a str/list/tuple.
        """
        if isinstance(value, str):
            value = [value]

        if replace:
            if isinstance(value, (list, tuple)):
                # Store a copy so later extend() calls never mutate the caller's list.
                self._fits = list(value)
            else:
                raise TypeError()
        else:
            self._fits.extend(value)

    def _update_actions(self):
        self.actions.update({'sort': self._data.sort,
                             'cut': self._data.cut,
                             'norm': self._data.normalize,
                             'center': self.center})

    @plot_update
    def update(self, opts: dict):
        self._data.update(opts)

    def get_properties(self) -> dict:
        """Return the editable properties, grouped into General/Symbol/Line."""
        props = OrderedDict()
        props['General'] = OrderedDict([('Name', self.name), ('Value', str(self.value)), ('Group', str(self.group))])
        props['Symbol'] = OrderedDict()
        props['Line'] = OrderedDict()

        props['Symbol']['Symbol'] = self.plot_real.symbol
        props['Symbol']['Size'] = self.plot_real.symbolsize
        props['Symbol']['Color'] = self.plot_real.symbolcolor

        props['Line']['Style'] = self.plot_real.linestyle
        props['Line']['Width'] = self.plot_real.linewidth
        props['Line']['Color'] = self.plot_real.linecolor

        if self.plot_imag is not None:
            props['Symbol']['Symbol (imag)'] = self.plot_imag.symbol
            props['Symbol']['Size (imag)'] = self.plot_imag.symbolsize
            props['Symbol']['Color (imag)'] = self.plot_imag.symbolcolor

            props['Line']['Style (imag)'] = self.plot_imag.linestyle
            props['Line']['Width (imag)'] = self.plot_imag.linewidth
            props['Line']['Color (imag)'] = self.plot_imag.linecolor

        return props

    def setColor(self, color, symbol=False, line=False, mode='real'):
        """Set the color of the 'real', 'imag' or 'all' plot items."""
        if mode in ['real', 'all']:
            self.plot_real.set_color(color, symbol=symbol, line=line)
            if self.plot_error is not None:
                err_pen = self.plot_error.opts['pen']
                err_pen.setColor(QtGui.QColor(*self.plot_real.symbolcolor.rgb()))
                self.plot_error.setData(pen=err_pen)

        if mode in ['imag', 'all'] and self.plot_imag is not None:
            self.plot_imag.set_color(color, symbol=symbol, line=line)

    def setSymbol(self, *, symbol=None, color=None, size=None, mode='real'):
        """Change symbol style of the 'real', 'imag' or 'all' plot items.

        With ``mode='all'`` both the real and (if present) the imaginary plot
        are updated. Prints a message if nothing could be updated.
        """
        updated = False

        if mode in ['real', 'all']:
            self.plot_real.set_symbol(symbol=symbol, size=size, color=color)
            # Error bars follow the symbol color when symbols are shown.
            if color is not None and self.plot_error is not None and self.plot_real.symbol != SymbolStyle.No:
                err_pen = self.plot_error.opts['pen']
                err_pen.setColor(QtGui.QColor(*self.plot_real.symbolcolor.rgb()))
                self.plot_error.setData(pen=err_pen)
            updated = True

        if mode in ['imag', 'all'] and self.plot_imag is not None:
            self.plot_imag.set_symbol(symbol=symbol, size=size, color=color)
            updated = True

        if not updated:
            print('Updating symbol failed for ' + str(self.id))

    def setLine(self, *, width=None, style=None, color=None, mode='real'):
        """Change line style of the 'real', 'imag' or 'all' plot items.

        With ``mode='all'`` both the real and (if present) the imaginary plot
        are updated. Prints a message if nothing could be updated.
        """
        updated = False

        if mode in ['real', 'all']:
            self.plot_real.set_line(width=width, style=style, color=color)
            # Error bars follow the line color only when no symbols are shown.
            if color is not None and self.plot_error is not None and self.plot_real.symbol == SymbolStyle.No:
                err_pen = self.plot_error.opts['pen']
                err_pen.setColor(QtGui.QColor(*self.plot_real.linecolor.rgb()))
                self.plot_error.setData(pen=err_pen)
            updated = True

        if mode in ['imag', 'all'] and self.plot_imag is not None:
            self.plot_imag.set_line(width=width, style=style, color=color)
            updated = True

        if not updated:
            print('Updating line failed for ' + str(self.id))

    def update_property(self, key1: str, key2: str, value: Any):
        """Apply one entry of the :meth:`get_properties` table.

        *key1* is the section ('General', 'Symbol', 'Line'), *key2* the entry
        name; entries ending in ' (imag)' address the imaginary plot.
        """
        keykey = key2.split()
        if len(keykey) == 1:
            if key1 == 'Symbol':
                self.setSymbol(mode='real', **{key2.lower(): value})

            elif key1 == 'Line':
                self.setLine(mode='real', **{key2.lower(): value})

            elif key1 == 'General':
                setattr(self, key2.lower(), value)

        else:
            if key1 == 'Symbol':
                self.setSymbol(mode='imag', **{keykey[0].lower(): value})

            elif key1 == 'Line':
                self.setLine(mode='imag', **{keykey[0].lower(): value})

    def points(self, params: dict):
        return self._data.points(**params)

    @plot_update
    def apply(self, func: str, args: tuple):
        """Run the registered action *func* with *args*; unknown names are ignored."""
        if func in self.actions:
            f = self.actions[func]
            f(*args)

        return self

    @plot_update
    def unsort(self, order: np.ndarray):
        # this exists only to update plots after an undo action
        self._data.x = self._data.x[order]
        self._data.y = self._data.y[order]
        self._data.y_err = self._data.y_err[order]
        self._data.mask = self._data.mask[order]

    def save(self, fname):
        """Save to *fname* (a path object with ``.suffix``): '.agr' as Grace file, '.dat'/'.txt' as text."""
        ext = fname.suffix
        if ext == '.agr':
            dic = self._manager.graphs[self.graph].get_state()
            dic['items'] = [self.plot_real.get_data_opts()]
            if self.plot_imag is not None:
                dic['items'].append(self.plot_imag.get_data_opts())

            GraceExporter(dic).export(fname)

        elif ext in ['.dat', '.txt']:
            self._data.savetxt(fname, err=True)

        else:
            raise ValueError('Unknown extension ' + ext)

    @plot_update
    def setvalues(self, pos, valium):
        """Set a single value; ``pos = (column, index)`` with column 0=x, 1=y, otherwise y_err."""
        xy, position = pos
        if xy == 0:
            self._data.x[position] = valium
        elif xy == 1:
            self._data.y[position] = valium
        else:
            self._data.y_err[position] = valium

    @property
    def mask(self):
        return self._data.mask

    @mask.setter
    @plot_update
    def mask(self, m):
        self._data.mask = np.asarray(m, dtype=bool)

    @plot_update
    def add(self, m):
        """Append points given as (x, y, y_err) sequence, Points or container."""
        if isinstance(m, (np.ndarray, list, tuple)):
            self._data.append(m[0], m[1], y_err=m[2])
        elif isinstance(m, (Points, ExperimentContainer)):
            self._data.append(m.x, m.y, y_err=m.y_err)
        else:
            # str() is required: concatenating a type object to a str raises.
            raise TypeError('Unknown type ' + str(type(m)))

    @plot_update
    def remove(self, m):
        self._data.remove(m)

    @plot_update
    def center(self) -> float:
        """Shift x so the maximum of the real part of y lies at 0; return the shift."""
        offset = self.x[np.argmax(self.y.real)]
        self._data._x -= offset

        return offset

    def get_namespace(self, i: int = None, j: int = None) -> dict:
        """Return ``{name: (value, description)}`` entries for the expression namespace.

        Without *i*/*j* the names are unprefixed, otherwise they are prefixed
        with ``g[i].s[j].``. Parameters of associated fits are added as
        ``fit['name']`` (one fit) or ``fit['name_k']`` (several fits).
        """
        if (i is None) and (j is None):
            prefix = ''
        else:
            prefix = 'g[%i].s[%i].' % (i, j)

        namespace = {prefix + 'x': (self.x, 'x values'),
                     prefix + 'y': (self.y, 'y values'),
                     prefix + 'y_err': (self.y_err, 'y error values'),
                     prefix + 'value': (self.value, str(self.value))}

        if len(self._fits) == 1:
            namespace.update({
                "%sfit['%s']" % (prefix, convert(pname, old='tex', new='str')): (pvalue.value, str(pvalue.value))
                for (pname, pvalue) in self._manager[self._fits[0]].parameter.items()
            })
        else:
            for k, f in enumerate(self._fits):
                namespace.update({
                    "%sfit['%s_%d']" % (prefix, convert(pname, old='tex', new='str'), k): (pvalue.value, str(pvalue.value))
                    for (pname, pvalue) in self._manager[f].parameter.items()
                })

        return namespace

    def eval_expression(self, cmds, namespace):
        """Execute *cmds* on x, y, y_err, value (and fit parameters); return a new data object.

        The full (unmasked) arrays are exposed; the result is a copy of the data
        object with the updated arrays and value.
        """
        namespace.update({'x': self._data.x, 'y': self._data.y, 'y_err': self._data.y_err, 'value': self.value})

        if len(self._fits) == 1:
            namespace.update({"fit['%s']" % (convert(pname, old='tex', new='str')): pvalue.value
                              for (pname, pvalue) in self._manager[self._fits[0]].parameter.items()})
        else:
            for k, f in enumerate(self._fits):
                namespace.update({"fit['%s_%i']" % (convert(pname, old='tex', new='str'), k): pvalue.value
                                  for (pname, pvalue) in self._manager[f].parameter.items()})

        new_data = self.copy()
        for c in cmds:
            if c:
                # NOTE(review): executes arbitrary code with module globals; only
                # acceptable for trusted, user-typed commands.
                exec(c, globals(), namespace)

        new_data.set_data(x=namespace['x'], y=namespace['y'], y_err=namespace['y_err'], replace_mask=False)
        new_data.value = namespace['value']

        return new_data

    def binning(self, digits: float):
        """Return a full copy whose data is binned via the data object's ``binning``."""
        new_data = self.copy(full=True)
        new_data.data = self._data.binning(value=digits)

        return new_data


class PointContainer(ExperimentContainer):
    """Container for generic point data (mode 'pts', or 'dsc' for DSC data)."""

    symbols = symbolcycle()

    def __init__(self, identifier, data, **kwargs):
        super().__init__(identifier, data, **kwargs)

        self.mode = 'pts'
        self._init_plot(**kwargs)

        if isinstance(self._data, DSC):
            self.mode = 'dsc'

    def _init_plot(self, **kwargs):
        self.plot_imag = None

        color = kwargs.get('color', None)
        symcolor = kwargs.get('symbolcolor', color)
        linecolor = kwargs.get('linecolor', color)

        # Missing colors fall back to each other, or to the next cycle color.
        if symcolor is None and linecolor is None:
            color = next(self.colors)
            symcolor = color
            linecolor = color
        elif symcolor is None:
            symcolor = linecolor
        elif linecolor is None:
            linecolor = symcolor

        sym_kwargs = {
            'symbol': kwargs.get('symbol', None),
            'size': kwargs.get('symbolsize', 10),
            'color': symcolor
        }

        line_kwargs = {
            'style': kwargs.get('linestyle', None),
            'width': kwargs.get('linewidth', 1),
            'color': linecolor
        }

        # Without explicit style: large sets (>500 points) as line, small ones as symbols.
        if sym_kwargs['symbol'] is None and line_kwargs['style'] is None:
            if len(self._data) > 500:
                line_kwargs['style'] = LineStyle.Solid
                sym_kwargs['symbol'] = SymbolStyle.No
            else:
                line_kwargs['style'] = LineStyle.No
                sym_kwargs['symbol'] = next(PointContainer.symbols)

        self.plot_real = PlotItem(x=self.x, y=self.y, name=self.name, symbol=None, pen=None, connect='finite')

        self.setSymbol(mode='real', **sym_kwargs)
        self.setLine(mode='real', **line_kwargs)

        # Error bars take the symbol color if symbols are drawn, else the line color.
        if sym_kwargs['symbol'] != SymbolStyle.No:
            self.plot_error = ErrorBars(x=self.x, y=self.y, top=self.y_err, bottom=self.y_err,
                                        pen=mkPen({'color': self.plot_real.symbolcolor.rgb()}))
        else:
            self.plot_error = ErrorBars(x=self.x, y=self.y, top=self.y_err, bottom=self.y_err,
                                        pen=mkPen({'color': self.plot_real.linecolor.rgb()}))


class FitContainer(ExperimentContainer):
    """Container for a fit result, drawn as line(s) without symbols (mode 'fit')."""

    def __init__(self, identifier, data, **kwargs):
        super().__init__(identifier, data, **kwargs)
        self.fitted_key = kwargs.get('src', '')
        self.mode = 'fit'
        self.parent_set = kwargs.get('src', '')

        self._init_plot(**kwargs)

        # Mirror fit statistics of the data object as container attributes.
        for n in ['statistics', 'nobs', 'nvar', 'parameter', 'model_name']:
            setattr(self, n, getattr(data, n))

    def _init_plot(self, **kwargs):
        color = kwargs.get('color')
        if color is None:
            color = kwargs.get('linecolor', (0, 0, 0))

        if isinstance(color, BaseColor):
            color = color.rgb()

        self.plot_real = PlotItem(x=self.x, y=self.y.real, name=self.name,
                                  pen=mkPen({'color': color}),
                                  connect='finite', symbol=None)

        if np.iscomplexobj(self._data.y):
            self.plot_imag = PlotItem(x=self.x, y=self.y.imag, name=self.name,
                                      pen=mkPen({'color': color}),
                                      connect='finite', symbol=None)

    @property
    def fitted_key(self):
        return self._data.idx

    @fitted_key.setter
    def fitted_key(self, val):
        self._data.idx = val

    def get_namespace(self, i: int = None, j: int = None):
        """Extend the base namespace with this fit's own parameters as ``fit['name']``.

        Uses the same prefix rule as the base class: none without *i*/*j*,
        ``g[i].s[j].`` otherwise.
        """
        namespace = super().get_namespace(i, j)

        if (i is None) and (j is None):
            prefix = ''
        else:
            prefix = 'g[%i].s[%i].' % (i, j)

        namespace.update({
            "%sfit['%s']" % (prefix, convert(pname, old='latex', new='plain')): (pvalue.value, str(pvalue.value))
            for (pname, pvalue) in self._data.parameter.items()
        })

        return namespace


class SignalContainer(ExperimentContainer):
    """Container for complex signals (FID, Spectrum, BDS) with real and imaginary plots."""

    # NOTE(review): the BDS branch in _init_plot draws from PointContainer.symbols,
    # so this cycle is currently unused — confirm intent.
    symbols = symbolcycle()

    def __init__(self, identifier, data, symbol=None, **kwargs):
        super().__init__(identifier, data, **kwargs)

        self.mode = 'signal'
        self._init_plot(symbol=symbol, **kwargs)

    def _init_plot(self, **kwargs):
        self.plot_real = PlotItem(x=self.x, y=self.y.real, name=self.name, symbol=None, pen=None, connect='finite')
        self.plot_imag = PlotItem(x=self.x, y=self.y.imag, name=self.name, symbol=None, pen=None, connect='finite')

        color = kwargs.get('color', None)
        symcolor = kwargs.get('symbolcolor', color)
        linecolor = kwargs.get('linecolor', color)

        # Missing colors fall back to each other, or to the next cycle color.
        if symcolor is None and linecolor is None:
            color = next(self.colors)
            symcolor = color
            linecolor = color
        elif symcolor is None:
            symcolor = linecolor
        elif linecolor is None:
            linecolor = symcolor

        sym_kwargs = {
            'symbol': kwargs.get('symbol', None),
            'size': kwargs.get('symbolsize', 10),
            'color': symcolor
        }

        line_kwargs = {
            'style': kwargs.get('linestyle', None),
            'width': kwargs.get('linewidth', 1),
            'color': linecolor
        }

        if isinstance(self._data, BDS):
            self.mode = 'bds'
            # Without explicit style: small sets (<=91 points) as symbols, larger ones as line.
            if sym_kwargs['symbol'] is None and line_kwargs['style'] is None:
                if len(self._data) <= 91:
                    sym_kwargs['symbol'] = next(PointContainer.symbols)
                    line_kwargs['style'] = LineStyle.No
                else:
                    line_kwargs['style'] = LineStyle.Solid
                    sym_kwargs['symbol'] = SymbolStyle.No

        elif isinstance(self._data, Signal):
            if line_kwargs['style'] is None and sym_kwargs['symbol'] is None:
                line_kwargs['style'] = LineStyle.Solid
                sym_kwargs['symbol'] = SymbolStyle.No

            if isinstance(self._data, FID):
                self.mode = 'fid'
            else:
                self.mode = 'spectrum'
        else:
            raise TypeError('Unknown class %s, should be FID, Spectrum, or BDS.' % type(self._data))

        for mode in ['real', 'imag']:
            # The imaginary part is drawn dashed whenever a line is shown.
            if mode == 'imag' and line_kwargs['style'] != LineStyle.No:
                line_kwargs['style'] = LineStyle.Dashed
            self.setSymbol(mode=mode, **sym_kwargs)
            self.setLine(mode=mode, **line_kwargs)

    def _update_actions(self):
        super()._update_actions()

        self.actions.update({'ph': self._data.manual_phase,
                             'bls': self._data.baseline_spline,
                             'autoph': self._data.autophase})

        if isinstance(self._data, Spectrum):
            self.actions.update({'bl': self._data.baseline,
                                 'ls': self._data.shift,
                                 'divide': self._data.divide,
                                 'ft': self.fourier})
            self.mode = 'spectrum'

        elif isinstance(self._data, FID):
            self.actions.update({'bl': self._data.baseline,
                                 'ls': self._data.shift,
                                 'zf': self._data.zerofill,
                                 'divide': self._data.divide,
                                 'ap': self._data.apod,
                                 'ft': self.fourier})
            self.mode = 'fid'

    @plot_update
    def fourier(self, mode='normal'):
        """Replace the data by its Fourier transform ('normal') or ``fft_depake`` result ('depake')."""
        if mode == 'normal':
            self._data = self._data.fourier()
        elif mode == 'depake':
            try:
                self._data = self._data.fft_depake()
            except AttributeError:
                # Data object has no fft_depake: leave everything unchanged.
                return
        # The data object changed, so rebind its actions.
        self._update_actions()

        return self