nmreval/src/gui_qt/data/container.py

683 lines
24 KiB
Python
Raw Normal View History

2022-10-20 15:23:15 +00:00
from __future__ import annotations
2022-03-08 09:27:40 +00:00
from collections import OrderedDict
from itertools import cycle
from typing import Any
import numpy as np
from pyqtgraph import mkPen
2022-10-20 15:23:15 +00:00
from nmreval.data.points import Points
from nmreval.data.signals import Signal
from nmreval.utils.text import convert
from nmreval.data.bds import BDS
from nmreval.lib.colors import BaseColor, TUColors
from nmreval.lib.lines import LineStyle
from nmreval.lib.symbols import SymbolStyle, symbolcycle
from nmreval.data.nmr import Spectrum, FID
2022-03-08 09:27:40 +00:00
from ..Qt import QtCore, QtGui
from ..io.exporters import GraceExporter
2022-03-08 09:27:40 +00:00
from ..lib.decorators import plot_update
from ..lib.pg_objects import ErrorBars, PlotItem
class ExperimentContainer(QtCore.QObject):
    """Qt-side wrapper around one data set and the plot items that display it.

    Subclasses create the concrete plot items in ``_init_plot``. This base class
    provides:

    * data access that respects the data mask;
    * styling of the real/imaginary plot items and the error bars;
    * fit bookkeeping, saving, and evaluation of user expressions.
    """

    dataChanged = QtCore.pyqtSignal(str)
    labelChanged = QtCore.pyqtSignal(str, str)
    # Emitted as (container id, new group) by the ``group`` setter.
    groupChanged = QtCore.pyqtSignal(str, str)
    # Class-level iterator: every container drawing from it gets the next color.
    colors = cycle(TUColors)

    def __init__(self, identifier, data, **kwargs):
        """Wrap *data* under the id ``str(identifier)``.

        Args:
            identifier: Id of this container; stored as a string.
            data: The wrapped data object (e.g. ``Points`` or a ``Signal`` subclass).
            **kwargs: ``manager`` is stored and used to resolve fit ids and graphs.
                Subclasses forward the remaining keywords to ``_init_plot``.
        """
        super().__init__()
        self.id = str(identifier)
        self._fits = []  # fit ids, resolved through ``self._manager``
        self._data = data
        self._manager = kwargs.get('manager')
        self.graph = ''  # key into ``self._manager.graphs`` (used by ``save``)

        self.mode = 'point'
        self.plot_real = None
        self.plot_imag = None
        self.plot_error = None

        # action name -> callable; dispatched by ``apply``
        self.actions = {}
        self._update_actions()

    def _init_plot(self, **kwargs):
        """Create the plot items; must be implemented by subclasses."""
        raise NotImplementedError

    def __getitem__(self, item):
        """Look up *item* in the wrapped data; re-raises KeyError with a clearer message."""
        try:
            return self._data[item]
        except KeyError:
            raise KeyError('Unknown key %s' % str(item)) from None

    def __del__(self):
        del self._data
        del self.plot_real
        del self.plot_imag
        del self.plot_error

    def __repr__(self):
        return 'name:' + self.name

    def __len__(self):
        return len(self._data)

    def copy(self, full: bool = False, keep_color: bool = True):
        """Copy the data, or the whole container.

        Args:
            full: If True, return a new container of the same type holding a copy
                of the data, with the same id and manager. If False, return only a
                copy of the wrapped data object.
            keep_color: With *full*, also copy the symbol and line styling of the
                real (and, if present, imaginary) plot item.
        """
        if not full:
            return self._data.copy()

        pen_dict = {}
        if keep_color:
            pen_dict = {
                'symbol': self.plot_real.symbol,
                'symbolcolor': self.plot_real.symbolcolor,
                'symbolsize': self.plot_real.symbolsize,
                'linestyle': self.plot_real.linestyle,
                'linecolor': self.plot_real.linecolor,
                'linewidth': self.plot_real.linewidth,
            }
        new_data = type(self)(str(self.id), self._data.copy(), manager=self._manager, **pen_dict)

        new_data.mode = self.mode
        if keep_color and self.plot_imag is not None:
            new_data.plot_imag.set_symbol(symbol=self.plot_imag.symbol, size=self.plot_imag.symbolsize,
                                          color=self.plot_imag.symbolcolor)
            new_data.plot_imag.set_line(style=self.plot_imag.linestyle, width=self.plot_imag.linewidth,
                                        color=self.plot_imag.linecolor)

        return new_data

    def change_type(self, data):
        """Return a new container of the type matching *data*, with this id and manager.

        FID, Spectrum and BDS data get a ``SignalContainer``; Points get a
        ``PointContainer``. Plot styling is not carried over.

        Raises:
            TypeError: If *data* is none of these types.
        """
        if isinstance(data, (FID, Spectrum, BDS)):
            new_type = SignalContainer
        elif isinstance(data, Points):
            new_type = PointContainer
        else:
            raise TypeError('Unknown data type')

        return new_type(str(self.id), data, manager=self._manager)

    @property
    def x(self):
        """x values of the unmasked points."""
        return self._data.x[self._data.mask]

    @x.setter
    @plot_update
    def x(self, value):
        # ``value`` may cover all points (mask is kept) or only the unmasked ones.
        if len(self._data.x) == len(value):
            self._data.x = value
        elif len(self._data.x[self._data.mask]) == len(value):
            # Drop the masked points from y AND y_err so all arrays keep the
            # same length, then unmask everything. The mask is re-read at every
            # step, as before, in case the data setters update it themselves.
            self._data.x = value
            self._data.y = self._data.y[self._data.mask]
            self._data.y_err = self._data.y_err[self._data.mask]
            self._data.mask = np.ma.array(np.ones_like(self._data.x, dtype=bool))
        else:
            raise ValueError('x and y have different dimensions!')

    @property
    def y(self):
        """y values of the unmasked points."""
        return self._data.y[self._data.mask]

    @y.setter
    @plot_update
    def y(self, value):
        # ``value`` may cover all points (mask is kept) or only the unmasked ones.
        if len(self._data.y) == len(value):
            self._data.y = value
        elif len(self._data.y[self._data.mask]) == len(value):
            # Drop the masked points from x AND y_err so all arrays keep the
            # same length, then unmask everything.
            self._data.y = value
            self._data.x = self._data.x[self._data.mask]
            self._data.y_err = self._data.y_err[self._data.mask]
            self._data.mask = np.ma.array(np.ones_like(self._data.y, dtype=bool))
        else:
            raise ValueError('x and y have different dimensions!')

    @property
    def y_err(self):
        """y errors of the unmasked points."""
        return self._data.y_err[self._data.mask]

    @y_err.setter
    @plot_update
    def y_err(self, value):
        # Full-length input replaces everything; unmasked-length input only
        # overwrites the unmasked entries.
        if len(self._data.y_err) == len(value):
            self._data.y_err = value
        elif len(self._data.y[self._data.mask]) == len(value):
            self._data.y_err[self._data.mask] = value
        else:
            raise ValueError('y_err has not correct length')

    @property
    def name(self):
        return self._data.name

    @name.setter
    @plot_update
    def name(self, value: str):
        # Also renames the plot items; a numeric name is additionally used as ``value``.
        self._data.name = value
        self.plot_real.opts['name'] = value
        try:
            self.plot_imag.opts['name'] = value
        except AttributeError:
            pass  # no imaginary plot item
        try:
            num_val = float(value)
            self._data.value = num_val
        except ValueError:
            pass

    @property
    def value(self):
        return self._data.value

    @value.setter
    def value(self, val):
        self._data.value = float(val)

    @property
    def group(self):
        return str(self._data['group'])

    @group.setter
    def group(self, valium):
        self._data['group'] = str(valium)
        self.groupChanged.emit(self.id, str(valium))

    @property
    def data(self):
        return self._data

    @data.setter
    @plot_update
    def data(self, new_data):
        self._data = new_data
        self._update_actions()  # actions are bound methods of the old data object

    @property
    def opts(self):
        return self._data.meta

    @property
    def plots(self):
        return self.plot_real, self.plot_imag, self.plot_error

    def get_state(self) -> dict:
        """Return a dict with id, data state, mode, fit ids and the plot styling.

        The styling of each plot item is stored as a (symbol dict, line dict) pair.
        The 'imag' entry is only present when there is an imaginary plot item.
        """
        ret_dic = {
            'id': self.id,
            'data': self._data.get_state(),
            'mode': self.mode,
            'fits': self._fits,
            'real': ({'symbol': self.plot_real.symbol.value,
                      'size': self.plot_real.symbolsize,
                      'color': self.plot_real.symbolcolor.value},
                     {'style': self.plot_real.linestyle.value,
                      'width': self.plot_real.linewidth,
                      'color': self.plot_real.linecolor.value})
        }
        if self.plot_imag is not None:
            ret_dic['imag'] = ({'symbol': self.plot_imag.symbol.value,
                                'size': self.plot_imag.symbolsize,
                                'color': self.plot_imag.symbolcolor.value},
                               {'style': self.plot_imag.linestyle.value,
                                'width': self.plot_imag.linewidth,
                                'color': self.plot_imag.linecolor.value})
        return ret_dic

    def get_fits(self):
        """Return the fit objects registered with the manager under this set's fit ids."""
        return [self._manager[idx] for idx in self._fits]

    def has_fits(self):
        return len(self._fits) != 0

    def set_fits(self, value: str | list, replace: bool = False):
        """Add fit ids, or replace them all.

        Args:
            value: A single fit id or a list of ids.
            replace: If True, *value* (which must be a list) replaces the current ids.

        Raises:
            TypeError: If *replace* is True and *value* is not a list.
        """
        if isinstance(value, str):
            value = [value]

        if replace:
            if not isinstance(value, list):
                raise TypeError()
            # Store a copy so later ``extend`` calls do not mutate the caller's list.
            self._fits = list(value)
        else:
            self._fits.extend(value)

    def _update_actions(self):
        """Register the data operations that ``apply`` can dispatch to."""
        self.actions.update({'sort': self._data.sort,
                             'cut': self._data.cut,
                             'norm': self._data.normalize,
                             'center': self.center})

    @plot_update
    def update(self, opts: dict):
        self._data.update(opts)

    def get_properties(self) -> dict:
        """Return the editable properties as nested OrderedDicts.

        The sections are 'General', 'Symbol' and 'Line'. Entries for the
        imaginary plot item carry an ' (imag)' suffix. ``update_property``
        accepts the same section/entry keys.
        """
        props = OrderedDict()
        props['General'] = OrderedDict([('Name', self.name),
                                        ('Value', str(self.value)),
                                        ('Group', str(self.group))])

        props['Symbol'] = OrderedDict()
        props['Line'] = OrderedDict()

        props['Symbol']['Symbol'] = self.plot_real.symbol
        props['Symbol']['Size'] = self.plot_real.symbolsize
        props['Symbol']['Color'] = self.plot_real.symbolcolor

        props['Line']['Style'] = self.plot_real.linestyle
        props['Line']['Width'] = self.plot_real.linewidth
        props['Line']['Color'] = self.plot_real.linecolor

        if self.plot_imag is not None:
            props['Symbol']['Symbol (imag)'] = self.plot_imag.symbol
            props['Symbol']['Size (imag)'] = self.plot_imag.symbolsize
            props['Symbol']['Color (imag)'] = self.plot_imag.symbolcolor

            props['Line']['Style (imag)'] = self.plot_imag.linestyle
            props['Line']['Width (imag)'] = self.plot_imag.linewidth
            props['Line']['Color (imag)'] = self.plot_imag.linecolor

        return props

    def _update_error_pen(self, color):
        """Recolor the error bars, if any, with *color* (an object providing ``rgb()``)."""
        if self.plot_error is None:
            return
        err_pen = self.plot_error.opts['pen']
        err_pen.setColor(QtGui.QColor(*color.rgb()))
        self.plot_error.setData(pen=err_pen)

    def setColor(self, color, symbol=False, line=False, mode='real'):
        """Set the color of the 'real', 'imag' or 'all' plot items.

        Error bars follow the symbol color of the real plot item.
        Prints a message if nothing could be updated.
        """
        updated = False
        if mode in ['real', 'all']:
            self.plot_real.set_color(color, symbol=symbol, line=line)
            self._update_error_pen(self.plot_real.symbolcolor)
            updated = True
        # Separate ``if`` (not ``elif``) so that mode='all' also reaches the imaginary part.
        if mode in ['imag', 'all'] and self.plot_imag is not None:
            self.plot_imag.set_color(color, symbol=symbol, line=line)
            updated = True
        if not updated:
            print('Updating color failed for ' + str(self.id))

    def setSymbol(self, symbol=None, color=None, size=None, mode='real'):
        """Set symbol shape/color/size of the 'real', 'imag' or 'all' plot items.

        Error bars take the new symbol color when the real plot item shows symbols.
        """
        updated = False
        if mode in ['real', 'all']:
            self.plot_real.set_symbol(symbol=symbol, size=size, color=color)
            if color is not None and self.plot_real.symbol != SymbolStyle.No:
                self._update_error_pen(self.plot_real.symbolcolor)
            updated = True
        if mode in ['imag', 'all'] and self.plot_imag is not None:
            self.plot_imag.set_symbol(symbol=symbol, size=size, color=color)
            updated = True
        if not updated:
            print('Updating symbol failed for ' + str(self.id))

    def setLine(self, width=None, style=None, color=None, mode='real'):
        """Set line width/style/color of the 'real', 'imag' or 'all' plot items.

        Error bars take the new line color when the real plot item shows no symbols.
        """
        updated = False
        if mode in ['real', 'all']:
            self.plot_real.set_line(width=width, style=style, color=color)
            if color is not None and self.plot_real.symbol == SymbolStyle.No:
                self._update_error_pen(self.plot_real.linecolor)
            updated = True
        if mode in ['imag', 'all'] and self.plot_imag is not None:
            self.plot_imag.set_line(width=width, style=style, color=color)
            updated = True
        if not updated:
            print('Updating line failed for ' + str(self.id))

    def update_property(self, key1: str, key2: str, value: Any):
        """Apply one property edit; keys as returned by ``get_properties``.

        A multi-word *key2* (e.g. 'Size (imag)') addresses the imaginary plot item.
        """
        keykey = key2.split()
        if len(keykey) == 1:
            if key1 == 'Symbol':
                self.setSymbol(mode='real', **{key2.lower(): value})
            elif key1 == 'Line':
                self.setLine(mode='real', **{key2.lower(): value})
            elif key1 == 'General':
                setattr(self, key2.lower(), value)
        else:
            if key1 == 'Symbol':
                self.setSymbol(mode='imag', **{keykey[0].lower(): value})
            elif key1 == 'Line':
                self.setLine(mode='imag', **{keykey[0].lower(): value})

    def points(self, params: dict):
        return self._data.points(**params)

    @plot_update
    def apply(self, func: str, args: tuple):
        """Run the action registered as *func* with *args*; unknown names are ignored."""
        if func in self.actions:
            f = self.actions[func]
            f(*args)

        return self

    @plot_update
    def unsort(self, order: np.ndarray):
        # this exists only to update plots after an undo action
        self._data.x = self._data.x[order]
        self._data.y = self._data.y[order]
        self._data.y_err = self._data.y_err[order]
        self._data.mask = self._data.mask[order]

    def save(self, fname):
        """Save to *fname* (a path object exposing ``.suffix``).

        '.agr' writes a Grace file of this set's graph with only this set's
        plot items; '.dat'/'.txt' writes text including errors.

        Raises:
            ValueError: For any other extension.
        """
        ext = fname.suffix
        if ext == '.agr':
            dic = self._manager.graphs[self.graph].get_state()
            dic['items'] = [self.plot_real.get_data_opts()]
            if self.plot_imag is not None:
                dic['items'].append(self.plot_imag.get_data_opts())

            GraceExporter(dic).export(fname)

        elif ext in ['.dat', '.txt']:
            self._data.savetxt(fname, err=True)

        else:
            raise ValueError('Unknown extension ' + ext)

    @plot_update
    def setvalues(self, pos, valium):
        """Set one raw (unmasked-index) value; *pos* is (column, index), column 0=x, 1=y, else y_err."""
        xy, position = pos
        if xy == 0:
            self._data.x[position] = valium
        elif xy == 1:
            self._data.y[position] = valium
        else:
            self._data.y_err[position] = valium

    @property
    def mask(self):
        return self._data.mask

    @mask.setter
    @plot_update
    def mask(self, m):
        self._data.mask = np.asarray(m, dtype=bool)

    @plot_update
    def add(self, m):
        """Append points given as a sequence (x, y, y_err) or taken from a Points/container.

        Raises:
            TypeError: For any other type of *m*.
        """
        if isinstance(m, (np.ndarray, list, tuple)):
            self._data.append(m[0], m[1], y_err=m[2])
        elif isinstance(m, (Points, ExperimentContainer)):
            self._data.append(m.x, m.y, y_err=m.y_err)
        else:
            # str() is required: concatenating str and type raises an unrelated TypeError.
            raise TypeError('Unknown type ' + str(type(m)))

    @plot_update
    def remove(self, m):
        self._data.remove(m)

    @plot_update
    def center(self) -> float:
        """Shift all x so that the maximum of the (unmasked) real part of y is at 0.

        Returns:
            The subtracted offset.
        """
        offset = self.x[np.argmax(self.y.real)]
        self._data._x -= offset

        return offset

    def get_namespace(self, i: int = None, j: int = None) -> dict:
        """Return the variables this data set contributes to an expression namespace.

        Every entry maps a name to ``(value, description)``. Without a position
        the names are plain ('x', 'y', ...). With graph index *i* and set index *j*
        they are prefixed by 'g[i].s[j].'. Fit parameters are added as
        fit['name'] for a single fit, or fit['name_k'] (k = fit number) for
        several fits.
        """
        if (i is None) and (j is None):
            prefix = ''
        else:
            prefix = 'g[%i].s[%i].' % (i, j)

        namespace = {prefix + 'x': (self.x, 'x values'),
                     prefix + 'y': (self.y, 'y values'),
                     prefix + 'y_err': (self.y_err, 'y error values'),
                     prefix + 'value': (self.value, str(self.value))}

        if len(self._fits) == 1:
            namespace.update({
                "%sfit['%s']" % (prefix, convert(pname, old='tex', new='str')): (pvalue.value, str(pvalue.value))
                for (pname, pvalue) in self._manager[self._fits[0]].parameter.items()
            })
        else:
            for k, f in enumerate(self._fits):
                namespace.update({
                    "%sfit['%s_%d']" % (prefix, convert(pname, old='tex', new='str'), k): (pvalue.value, str(pvalue.value))
                    for (pname, pvalue) in self._manager[f].parameter.items()
                })

        return namespace

    def eval_expression(self, cmds, namespace):
        """Execute *cmds* against this data and return a modified copy of the data.

        x, y, y_err, value and the fit parameters are injected into *namespace*
        (which is modified in place). After every non-empty command has been run,
        the resulting x/y/y_err/value are written to a copy of the data object.

        SECURITY: the commands are run with ``exec``, so arbitrary Python code is
        executed. Only pass trusted input.
        """
        namespace.update({'x': self.x, 'y': self.y, 'y_err': self.y_err, 'value': self.value})

        if len(self._fits) == 1:
            namespace.update({"fit['%s']" % (convert(pname, old='tex', new='str')): pvalue.value
                              for (pname, pvalue) in self._manager[self._fits[0]].parameter.items()})
        else:
            for k, f in enumerate(self._fits):
                namespace.update({"fit['%s_%i']" % (convert(pname, old='tex', new='str'), k): pvalue.value
                                  for (pname, pvalue) in self._manager[f].parameter.items()})

        new_data = self.copy()
        for c in cmds:
            if c:
                exec(c, globals(), namespace)

        new_data.set_data(x=namespace['x'], y=namespace['y'], y_err=namespace['y_err'])
        new_data.value = namespace['value']

        return new_data


class PointContainer(ExperimentContainer):
    """Container for plain point data: one real-valued plot item plus error bars.

    There is no imaginary plot item (``plot_imag`` stays None).
    """

    # Class-level symbol iterator; each new point set without an explicit style takes the next symbol.
    symbols = symbolcycle()

    def __init__(self, identifier, data, **kwargs):
        super().__init__(identifier, data, **kwargs)
        self.mode = 'pts'
        self._init_plot(**kwargs)

    def _init_plot(self, **kwargs):
        """Create the real plot item and the error bars.

        Recognized keywords:
            color: Fallback for both symbol and line color.
            symbolcolor, linecolor: Per-part colors. A missing one copies the
                other; if both are missing, the next color of the shared cycle is used.
            symbol, symbolsize (default 10): Symbol styling.
            linestyle, linewidth (default 1): Line styling.

        Without symbol and line style, data sets with more than 500 points are
        drawn as a solid line; smaller ones get the next symbol and no line.
        """
        self.plot_imag = None

        fallback = kwargs.get('color', None)
        sym_color = kwargs.get('symbolcolor', fallback)
        line_color = kwargs.get('linecolor', fallback)
        if sym_color is None:
            if line_color is None:
                sym_color = line_color = next(self.colors)
            else:
                sym_color = line_color
        elif line_color is None:
            line_color = sym_color

        symbol = kwargs.get('symbol', None)
        symbol_size = kwargs.get('symbolsize', 10)
        line_style = kwargs.get('linestyle', None)
        line_width = kwargs.get('linewidth', 1)

        if symbol is None and line_style is None:
            if len(self._data) > 500:
                line_style, symbol = LineStyle.Solid, SymbolStyle.No
            else:
                line_style, symbol = LineStyle.No, next(PointContainer.symbols)

        self.plot_real = PlotItem(x=self._data.x, y=self._data.y, name=self.name,
                                  symbol=None, pen=None, connect='finite')
        self.setSymbol(mode='real', symbol=symbol, size=symbol_size, color=sym_color)
        self.setLine(mode='real', style=line_style, width=line_width, color=line_color)

        # Error bars use the symbol color, or the line color if no symbols are drawn.
        bar_color = self.plot_real.symbolcolor if symbol != SymbolStyle.No else self.plot_real.linecolor
        self.plot_error = ErrorBars(x=self._data.x, y=self._data.y,
                                    top=self._data.y_err, bottom=self._data.y_err,
                                    pen=mkPen({'color': bar_color.rgb()}))


class FitContainer(ExperimentContainer):
    """Container for a fit curve, drawn as a line without symbols.

    A separate imaginary plot item is created when the fitted y values are complex.
    """

    def __init__(self, identifier, data, **kwargs):
        """Wrap fit *data*; the ``src`` keyword is the id of the fitted data set."""
        super().__init__(identifier, data, **kwargs)
        self.fitted_key = kwargs.get('src', '')
        self.mode = 'fit'
        self.parent_set = kwargs.get('src', '')
        self._init_plot(**kwargs)

        # Mirror the fit-result attributes of the data object on the container.
        for n in ['statistics', 'nobs', 'nvar', 'parameter', 'model_name']:
            setattr(self, n, getattr(data, n))

    def _init_plot(self, **kwargs):
        """Create line plot items in ``color`` (default black; BaseColor values are converted to RGB)."""
        color = kwargs.get('color', (0, 0, 0))
        if isinstance(color, BaseColor):
            color = color.rgb()

        self.plot_real = PlotItem(x=self._data.x, y=self._data.y.real, name=self.name,
                                  pen=mkPen({'color': color}),
                                  connect='finite', symbol=None)

        if np.iscomplexobj(self._data.y):
            self.plot_imag = PlotItem(x=self._data.x, y=self._data.y.imag, name=self.name,
                                      pen=mkPen({'color': color}),
                                      connect='finite', symbol=None)

    @property
    def fitted_key(self):
        """Id of the data set this fit belongs to (stored on the data object as ``idx``)."""
        return self._data.idx

    @fitted_key.setter
    def fitted_key(self, val):
        self._data.idx = val

    def get_namespace(self, i: int = None, j: int = None) -> dict:
        """Extend the base namespace with this fit's own parameters as fit['name'].

        Uses the same prefix rule as the base class: no prefix without a
        position, 'g[i].s[j].' otherwise. Previously the defaults i=j=None
        crashed on ``'%i' % None``.
        """
        namespace = super().get_namespace(i, j)

        if i is None and j is None:
            prefix = ''
        else:
            prefix = 'g[%i].s[%i].' % (i, j)

        namespace.update({
            "%sfit['%s']" % (prefix, convert(pname, old='latex', new='plain')): (pvalue.value, str(pvalue.value))
            for (pname, pvalue) in self._data.parameter.items()
        })

        return namespace


class SignalContainer(ExperimentContainer):
    """Container for FID, Spectrum or BDS data.

    The real and imaginary parts of y each get their own plot item. ``mode``
    ends up as 'bds', 'fid' or 'spectrum' depending on the data type.
    """

    symbols = symbolcycle()

    def __init__(self, identifier, data, symbol=None, **kwargs):
        super().__init__(identifier, data, **kwargs)
        self.mode = 'signal'
        self._init_plot(symbol=symbol, **kwargs)

    def _init_plot(self, **kwargs):
        """Create and style the real and imaginary plot items.

        Color keywords behave as in the point container: 'color' is the fallback,
        and a missing 'symbolcolor'/'linecolor' copies the other one. If both are
        missing, the next color of the shared cycle is used.

        Default styles (only when neither symbol nor linestyle is given):
            * BDS with at most 91 points: next symbol of ``PointContainer.symbols``, no line.
            * BDS with more points: solid line, no symbols.
            * FID/Spectrum: solid line, no symbols.
        The imaginary part is drawn dashed whenever a line is drawn.

        Raises:
            TypeError: If the data is neither BDS nor a Signal subclass.
        """
        signal = self._data
        self.plot_real = PlotItem(x=signal.x, y=signal.y.real, name=self.name,
                                  symbol=None, pen=None, connect='finite')
        self.plot_imag = PlotItem(x=signal.x, y=signal.y.imag, name=self.name,
                                  symbol=None, pen=None, connect='finite')

        fallback = kwargs.get('color', None)
        sym_color = kwargs.get('symbolcolor', fallback)
        line_color = kwargs.get('linecolor', fallback)
        if sym_color is None:
            if line_color is None:
                sym_color = line_color = next(self.colors)
            else:
                sym_color = line_color
        elif line_color is None:
            line_color = sym_color

        symbol = kwargs.get('symbol', None)
        symbol_size = kwargs.get('symbolsize', 10)
        line_style = kwargs.get('linestyle', None)
        line_width = kwargs.get('linewidth', 1)
        no_style_given = symbol is None and line_style is None

        # BDS is tested before the generic Signal branch.
        if isinstance(signal, BDS):
            self.mode = 'bds'
            if no_style_given:
                if len(signal) <= 91:
                    symbol, line_style = next(PointContainer.symbols), LineStyle.No
                else:
                    line_style, symbol = LineStyle.Solid, SymbolStyle.No
        elif isinstance(signal, Signal):
            if no_style_given:
                line_style, symbol = LineStyle.Solid, SymbolStyle.No
            self.mode = 'fid' if isinstance(signal, FID) else 'spectrum'
        else:
            raise TypeError('Unknown class %s, should be FID, Spectrum, or BDS.' % type(signal))

        self.setSymbol(mode='real', symbol=symbol, size=symbol_size, color=sym_color)
        self.setLine(mode='real', style=line_style, width=line_width, color=line_color)

        imag_style = LineStyle.Dashed if line_style != LineStyle.No else line_style
        self.setSymbol(mode='imag', symbol=symbol, size=symbol_size, color=sym_color)
        self.setLine(mode='imag', style=imag_style, width=line_width, color=line_color)

    def _update_actions(self):
        """Add the signal operations (phase, baselines, FT, ...) to the base actions.

        Also sets ``mode`` to 'spectrum' or 'fid' for those data types.
        """
        super()._update_actions()
        signal = self._data
        self.actions.update({'ph': signal.manual_phase, 'bls': signal.baseline_spline})

        if isinstance(signal, Spectrum):
            extra = {'bl': signal.baseline, 'ls': signal.shift,
                     'divide': signal.divide, 'ft': self.fourier}
            new_mode = 'spectrum'
        elif isinstance(signal, FID):
            extra = {'bl': signal.baseline, 'ls': signal.shift,
                     'zf': signal.zerofill, 'divide': signal.divide,
                     'ap': signal.apod, 'ft': self.fourier}
            new_mode = 'fid'
        else:
            return

        self.actions.update(extra)
        self.mode = new_mode

    @plot_update
    def fourier(self, mode='normal'):
        """Replace the data by its Fourier transform ('normal') or de-Paked FT ('depake').

        Returns None, leaving everything unchanged, if the data has no
        ``fft_depake``; otherwise returns self.
        """
        if mode == 'depake':
            try:
                transformed = self._data.fft_depake()
            except AttributeError:
                return None
            self._data = transformed
        elif mode == 'normal':
            self._data = self._data.fourier()

        # The data object changed, so the bound actions must be rebuilt.
        self._update_actions()

        return self