nmreval/src/gui_qt/data/container.py

689 lines
24 KiB
Python
Raw Normal View History

2022-10-20 15:23:15 +00:00
from __future__ import annotations
2022-03-08 09:27:40 +00:00
from collections import OrderedDict
from itertools import cycle
from typing import Any
import numpy as np
from pyqtgraph import mkPen
2022-10-20 15:23:15 +00:00
from nmreval.data.points import Points
from nmreval.data.signals import Signal
from nmreval.utils.text import convert
from nmreval.data.bds import BDS
from nmreval.data.dsc import DSC
2022-10-20 15:23:15 +00:00
from nmreval.lib.colors import BaseColor, TUColors
from nmreval.lib.lines import LineStyle
from nmreval.lib.symbols import SymbolStyle, symbolcycle
from nmreval.data.nmr import Spectrum, FID
2022-03-08 09:27:40 +00:00
from ..Qt import QtCore, QtGui
from ..io.exporters import GraceExporter
2022-03-08 09:27:40 +00:00
from ..lib.decorators import plot_update
from ..lib.pg_objects import ErrorBars, PlotItem
class ExperimentContainer(QtCore.QObject):
    """Base wrapper tying a data object to the plot items that display it.

    The wrapped data object (``self._data``) holds the values; this class
    exposes them through mask-aware properties, keeps the plot items
    (``plot_real``, ``plot_imag``, ``plot_error``) in sync with style changes
    and maps action names to callables in ``self.actions``.
    Subclasses must implement ``_init_plot`` to create the plot items.
    """

    dataChanged = QtCore.pyqtSignal(str)
    labelChanged = QtCore.pyqtSignal(str, str)
    groupChanged = QtCore.pyqtSignal(str, str)
    # Class attribute: one colour cycle shared by all containers, so that
    # consecutive data sets get different default colours.
    colors = cycle(TUColors)

    def __init__(self, identifier, data, **kwargs):
        """
        Args:
            identifier: Key of this data set, stored as string in ``self.id``.
            data: Wrapped data object (e.g. ``Points``, ``FID``, ``Spectrum``, ``BDS``).
            **kwargs: ``manager`` is stored as ``self._manager``; subclasses read
                plot-style keys from the remaining arguments.
        """
        super().__init__()

        self.id = str(identifier)
        self._fits = []
        self._data = data
        self._manager = kwargs.get('manager')
        self.graph = ''

        self.mode = 'point'
        self.plot_real = None
        self.plot_imag = None
        self.plot_error = None

        self.actions = {}
        self._update_actions()

    @plot_update
    def _init_plot(self):
        """Create the plot items; must be implemented by subclasses."""
        raise NotImplementedError

    def __getitem__(self, item):
        """Forward item access to the wrapped data object."""
        try:
            return self._data[item]
        except KeyError:
            raise KeyError('Unknown key %s' % str(item)) from None

    def __del__(self):
        del self._data
        del self.plot_real
        del self.plot_imag
        del self.plot_error

    def __repr__(self):
        return 'name:' + self.name

    def __len__(self):
        return len(self._data)

    def copy(self, full: bool = False, keep_color: bool = True):
        """Return a copy of this data set.

        Args:
            full: If True, return a new container of the same class (same id and
                manager) around a copy of the data. Otherwise return only a copy
                of the wrapped data object.
            keep_color: Only used with ``full=True``. Transfer the symbol and line
                styles of the plot items to the new container.
        """
        if not full:
            return self._data.copy()

        pen_dict = {}
        if keep_color:
            pen_dict = {
                'symbol': self.plot_real.symbol,
                'symbolcolor': self.plot_real.symbolcolor,
                'symbolsize': self.plot_real.symbolsize,
                'linestyle': self.plot_real.linestyle,
                'linecolor': self.plot_real.linecolor,
                'linewidth': self.plot_real.linewidth,
            }

        new_data = type(self)(str(self.id), self._data.copy(), manager=self._manager, **pen_dict)
        new_data.mode = self.mode

        if keep_color and self.plot_imag is not None:
            new_data.plot_imag.set_symbol(symbol=self.plot_imag.symbol, size=self.plot_imag.symbolsize,
                                          color=self.plot_imag.symbolcolor)
            new_data.plot_imag.set_line(style=self.plot_imag.linestyle, width=self.plot_imag.linewidth,
                                        color=self.plot_imag.linecolor)

        return new_data

    def change_type(self, data):
        """Wrap *data* in a new container of the matching class.

        ``FID``, ``Spectrum`` and ``BDS`` objects get a ``SignalContainer``,
        ``Points`` objects a ``PointContainer``. The id and manager are kept;
        plot styles are not transferred.

        Raises:
            TypeError: if *data* is none of these types.
        """
        if isinstance(data, (FID, Spectrum, BDS)):
            new_type = SignalContainer
        elif isinstance(data, Points):
            new_type = PointContainer
        else:
            raise TypeError('Unknown data type')

        return new_type(str(self.id), data, manager=self._manager)

    @property
    def x(self):
        """x values of the unmasked points."""
        return self._data.x[self._data.mask]

    @x.setter
    @plot_update
    def x(self, value):
        """Replace x.

        An array with the full length replaces x directly. An array matching
        only the unmasked points replaces x, drops the masked points from y and
        resets the mask.
        """
        if len(self._data.x) == len(value):
            self._data.x = value
        elif len(self._data.x[self._data.mask]) == len(value):
            self._data.x = value
            self._data.y = self._data.y[self._data.mask]
            # NOTE(review): y_err is not trimmed here although y is -- confirm
            # whether the data object keeps y_err consistent on its own.
            self._data.mask = np.ma.array(np.ones_like(self._data.x, dtype=bool))
        else:
            raise ValueError('x and y have different dimensions!')

    @property
    def y(self):
        """y values of the unmasked points."""
        return self._data.y[self._data.mask]

    @y.setter
    @plot_update
    def y(self, value):
        """Replace y.

        An array with the full length replaces y directly. An array matching
        only the unmasked points replaces y, drops the masked points from x and
        resets the mask.
        """
        if len(self._data.y) == len(value):
            self._data.y = value
        elif len(self._data.y[self._data.mask]) == len(value):
            self._data.y = value
            self._data.x = self._data.x[self._data.mask]
            # NOTE(review): y_err is not trimmed here although x is -- confirm
            # whether the data object keeps y_err consistent on its own.
            self._data.mask = np.ma.array(np.ones_like(self._data.y, dtype=bool))
        else:
            raise ValueError('x and y have different dimensions!')

    @property
    def y_err(self):
        """y errors of the unmasked points."""
        return self._data.y_err[self._data.mask]

    @y_err.setter
    @plot_update
    def y_err(self, value):
        """Replace y_err.

        Accepts either all points or only the unmasked points. In the latter
        case, only the unmasked entries are overwritten.
        """
        if len(self._data.y_err) == len(value):
            self._data.y_err = value
        elif len(self._data.y[self._data.mask]) == len(value):
            self._data.y_err[self._data.mask] = value
        else:
            raise ValueError('y_err has not correct length')

    @property
    def name(self):
        return self._data.name

    @name.setter
    @plot_update
    def name(self, value: str):
        """Rename data set and plot items.

        If the name parses as a float, it is also stored as ``value``.
        """
        self._data.name = value
        self.plot_real.opts['name'] = value
        try:
            self.plot_imag.opts['name'] = value
        except AttributeError:
            # no imaginary plot item (plot_imag is None)
            pass

        try:
            num_val = float(value)
            self._data.value = num_val
        except ValueError:
            pass

    @property
    def value(self):
        return self._data.value

    @value.setter
    def value(self, val):
        self._data.value = float(val)

    @property
    def group(self):
        return str(self._data['group'])

    @group.setter
    def group(self, value):
        """Store the group name in the data and emit ``groupChanged(id, group)``."""
        self._data['group'] = str(value)
        self.groupChanged.emit(self.id, str(value))

    @property
    def data(self):
        return self._data

    @data.setter
    @plot_update
    def data(self, new_data):
        """Replace the wrapped data object and rebuild the action table."""
        self._data = new_data
        self._update_actions()

    @property
    def opts(self):
        return self._data.meta

    @property
    def plots(self):
        return self.plot_real, self.plot_imag, self.plot_error

    def get_state(self) -> dict:
        """Return a dictionary with id, data state, mode, fits and plot styles."""
        ret_dic = {
            'id': self.id,
            'data': self._data.get_state(),
            'mode': self.mode,
            'fits': self._fits,
            'real': ({'symbol': self.plot_real.symbol.value,
                      'size': self.plot_real.symbolsize,
                      'color': self.plot_real.symbolcolor.value},
                     {'style': self.plot_real.linestyle.value,
                      'width': self.plot_real.linewidth,
                      'color': self.plot_real.linecolor.value})
        }

        if self.plot_imag is not None:
            ret_dic['imag'] = ({'symbol': self.plot_imag.symbol.value,
                                'size': self.plot_imag.symbolsize,
                                'color': self.plot_imag.symbolcolor.value},
                               {'style': self.plot_imag.linestyle.value,
                                'width': self.plot_imag.linewidth,
                                'color': self.plot_imag.linecolor.value})

        return ret_dic

    def get_fits(self):
        """Return the fit objects whose keys are attached to this data set."""
        return [self._manager[idx] for idx in self._fits]

    def has_fits(self):
        return len(self._fits) != 0

    def set_fits(self, value: str | list, replace: bool = False):
        """Attach fit keys to this data set.

        Args:
            value: A single key or a list of keys.
            replace: If True, replace the current keys (requires a list after
                wrapping a single key). Otherwise append to them.

        Raises:
            TypeError: if ``replace`` is True and *value* is not a list.
        """
        if isinstance(value, str):
            value = [value]

        if replace:
            if not isinstance(value, list):
                raise TypeError('fit keys must be given as list, got %s' % type(value))
            # copy so that later ``extend`` calls do not modify the caller's list
            self._fits = list(value)
        else:
            self._fits.extend(value)

    def _update_actions(self):
        """Register the actions available for the wrapped data object."""
        self.actions.update({'sort': self._data.sort,
                             'cut': self._data.cut,
                             'norm': self._data.normalize,
                             'center': self.center})

    @plot_update
    def update(self, opts: dict):
        self._data.update(opts)

    def get_properties(self) -> dict:
        """Return the editable properties grouped by 'General', 'Symbol' and 'Line'."""
        props = OrderedDict()
        props['General'] = OrderedDict([('Name', self.name),
                                        ('Value', str(self.value)),
                                        ('Group', str(self.group))])

        props['Symbol'] = OrderedDict()
        props['Line'] = OrderedDict()

        props['Symbol']['Symbol'] = self.plot_real.symbol
        props['Symbol']['Size'] = self.plot_real.symbolsize
        props['Symbol']['Color'] = self.plot_real.symbolcolor

        props['Line']['Style'] = self.plot_real.linestyle
        props['Line']['Width'] = self.plot_real.linewidth
        props['Line']['Color'] = self.plot_real.linecolor

        if self.plot_imag is not None:
            props['Symbol']['Symbol (imag)'] = self.plot_imag.symbol
            props['Symbol']['Size (imag)'] = self.plot_imag.symbolsize
            props['Symbol']['Color (imag)'] = self.plot_imag.symbolcolor

            props['Line']['Style (imag)'] = self.plot_imag.linestyle
            props['Line']['Width (imag)'] = self.plot_imag.linewidth
            props['Line']['Color (imag)'] = self.plot_imag.linecolor

        return props

    def _set_error_color(self, color):
        """Recolour the error bars; *color* must provide ``rgb()``."""
        err_pen = self.plot_error.opts['pen']
        err_pen.setColor(QtGui.QColor(*color.rgb()))
        self.plot_error.setData(pen=err_pen)

    def setColor(self, color, symbol=False, line=False, mode='real'):
        """Set colour of symbol and/or line for ``mode`` 'real', 'imag' or 'all'."""
        if mode in ['real', 'all']:
            self.plot_real.set_color(color, symbol=symbol, line=line)
            if self.plot_error is not None:
                self._set_error_color(self.plot_real.symbolcolor)

        if mode in ['imag', 'all'] and self.plot_imag is not None:
            self.plot_imag.set_color(color, symbol=symbol, line=line)

    def setSymbol(self, symbol=None, color=None, size=None, mode='real'):
        """Change symbol style of the 'real', 'imag' or 'all' plot items.

        Error bars follow a new symbol colour while symbols are visible.
        Prints a message if no plot item was updated.
        """
        updated = False

        if mode in ['real', 'all']:
            self.plot_real.set_symbol(symbol=symbol, size=size, color=color)
            if color is not None and self.plot_error is not None and self.plot_real.symbol != SymbolStyle.No:
                self._set_error_color(self.plot_real.symbolcolor)
            updated = True

        # independent ``if`` (not ``elif``) so that mode 'all' also reaches the imaginary part
        if mode in ['imag', 'all'] and self.plot_imag is not None:
            self.plot_imag.set_symbol(symbol=symbol, size=size, color=color)
            updated = True

        if not updated:
            print('Updating symbol failed for ' + str(self.id))

    def setLine(self, width=None, style=None, color=None, mode='real'):
        """Change line style of the 'real', 'imag' or 'all' plot items.

        Error bars follow a new line colour when no symbols are shown.
        Prints a message if no plot item was updated.
        """
        updated = False

        if mode in ['real', 'all']:
            self.plot_real.set_line(width=width, style=style, color=color)
            if color is not None and self.plot_error is not None and self.plot_real.symbol == SymbolStyle.No:
                self._set_error_color(self.plot_real.linecolor)
            updated = True

        # independent ``if`` (not ``elif``) so that mode 'all' also reaches the imaginary part
        if mode in ['imag', 'all'] and self.plot_imag is not None:
            self.plot_imag.set_line(width=width, style=style, color=color)
            updated = True

        if not updated:
            print('Updating line failed for ' + str(self.id))

    def update_property(self, key1: str, key2: str, value: Any):
        """Apply a property edit using the keys produced by ``get_properties``.

        A one-word ``key2`` (e.g. 'Color') targets the real part or a 'General'
        attribute. A two-word ``key2`` (e.g. 'Color (imag)') targets the
        imaginary part.
        """
        keykey = key2.split()
        if len(keykey) == 1:
            if key1 == 'Symbol':
                self.setSymbol(mode='real', **{key2.lower(): value})

            elif key1 == 'Line':
                self.setLine(mode='real', **{key2.lower(): value})

            elif key1 == 'General':
                setattr(self, key2.lower(), value)

        else:
            if key1 == 'Symbol':
                self.setSymbol(mode='imag', **{keykey[0].lower(): value})

            elif key1 == 'Line':
                self.setLine(mode='imag', **{keykey[0].lower(): value})

    def points(self, params: dict):
        return self._data.points(**params)

    @plot_update
    def apply(self, func: str, args: tuple):
        """Call the registered action *func* with *args*; unknown names are ignored."""
        if func in self.actions:
            f = self.actions[func]
            f(*args)

        return self

    @plot_update
    def unsort(self, order: np.ndarray):
        """Reorder x, y, y_err and mask by the index array *order*."""
        # this exists only to update plots after an undo action
        self._data.x = self._data.x[order]
        self._data.y = self._data.y[order]
        self._data.y_err = self._data.y_err[order]
        self._data.mask = self._data.mask[order]

    def save(self, fname):
        """Save to *fname* (path-like with ``.suffix``).

        '.agr' writes a Grace file of the graph this data set belongs to;
        '.dat' and '.txt' write text including errors.

        Raises:
            ValueError: for any other extension.
        """
        ext = fname.suffix
        if ext == '.agr':
            dic = self._manager.graphs[self.graph].get_state()
            dic['items'] = [self.plot_real.get_data_opts()]
            if self.plot_imag is not None:
                dic['items'].append(self.plot_imag.get_data_opts())

            GraceExporter(dic).export(fname)

        elif ext in ['.dat', '.txt']:
            self._data.savetxt(fname, err=True)

        else:
            raise ValueError('Unknown extension ' + ext)

    @plot_update
    def setvalues(self, pos, valium):
        """Set a single value; ``pos = (column, index)`` with column 0 = x, 1 = y, other = y_err."""
        xy, position = pos
        if xy == 0:
            self._data.x[position] = valium
        elif xy == 1:
            self._data.y[position] = valium
        else:
            self._data.y_err[position] = valium

    @property
    def mask(self):
        return self._data.mask

    @mask.setter
    @plot_update
    def mask(self, m):
        self._data.mask = np.asarray(m, dtype=bool)

    @plot_update
    def add(self, m):
        """Append points from a sequence ``(x, y, y_err)``, a ``Points`` object or another container.

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(m, (np.ndarray, list, tuple)):
            self._data.append(m[0], m[1], y_err=m[2])
        elif isinstance(m, (Points, ExperimentContainer)):
            self._data.append(m.x, m.y, y_err=m.y_err)
        else:
            # str() is required: concatenating a type to a str raises its own TypeError
            raise TypeError('Unknown type ' + str(type(m)))

    @plot_update
    def remove(self, m):
        self._data.remove(m)

    @plot_update
    def center(self) -> float:
        """Shift x so that the maximum of the real part of y lies at 0; return the shift."""
        offset = self.x[np.argmax(self.y.real)]
        self._data._x -= offset

        return offset

    def get_namespace(self, i: int = None, j: int = None) -> dict:
        """Return the variables this data set provides for expression evaluation.

        Keys are prefixed with ``g[i].s[j].`` unless both *i* and *j* are None.
        Values are ``(value, description)`` tuples. Parameters of attached fits
        are added as ``fit['name']`` (one fit) or ``fit['name_k']`` (k-th fit).
        """
        if (i is None) and (j is None):
            prefix = ''
        else:
            prefix = 'g[%i].s[%i].' % (i, j)

        namespace = {prefix + 'x': (self.x, 'x values'),
                     prefix + 'y': (self.y, 'y values'),
                     prefix + 'y_err': (self.y_err, 'y error values'),
                     prefix + 'value': (self.value, str(self.value))}

        if len(self._fits) == 1:
            namespace.update({
                "%sfit['%s']" % (prefix, convert(pname, old='tex', new='str')): (pvalue.value, str(pvalue.value))
                for (pname, pvalue) in self._manager[self._fits[0]].parameter.items()
            })
        else:
            for k, f in enumerate(self._fits):
                namespace.update({
                    "%sfit['%s_%d']" % (prefix, convert(pname, old='tex', new='str'), k): (pvalue.value, str(pvalue.value))
                    for (pname, pvalue) in self._manager[f].parameter.items()
                })

        return namespace

    def eval_expression(self, cmds, namespace):
        """Execute *cmds* and return a copy of the data holding the results.

        ``x``, ``y``, ``y_err`` (all points, masked ones included) and ``value``
        are put into *namespace* together with fit parameters. The commands are
        run with ``exec``, and the final ``x``, ``y``, ``y_err`` and ``value``
        are written into a copy of the wrapped data object (not a container),
        which is returned.

        Security: ``exec`` runs arbitrary code -- only pass trusted input.
        """
        # Copies: in-place commands like ``y *= 2`` must not modify this data
        # set, and the result must not share arrays with it.
        namespace.update({'x': self._data.x.copy(), 'y': self._data.y.copy(),
                          'y_err': self._data.y_err.copy(), 'value': self.value})

        if len(self._fits) == 1:
            namespace.update({"fit['%s']" % (convert(pname, old='tex', new='str')): pvalue.value
                              for (pname, pvalue) in self._manager[self._fits[0]].parameter.items()})
        else:
            for k, f in enumerate(self._fits):
                namespace.update({"fit['%s_%i']" % (convert(pname, old='tex', new='str'), k): pvalue.value
                                  for (pname, pvalue) in self._manager[f].parameter.items()})

        new_data = self.copy()
        for c in cmds:
            if c:
                exec(c, globals(), namespace)

        new_data.set_data(x=namespace['x'], y=namespace['y'], y_err=namespace['y_err'], replace_mask=False)
        new_data.value = namespace['value']

        return new_data
class PointContainer(ExperimentContainer):
    """Container for point data; DSC data is flagged with mode ``'dsc'``."""

    symbols = symbolcycle()

    def __init__(self, identifier, data, **kwargs):
        super().__init__(identifier, data, **kwargs)
        self.mode = 'pts'
        self._init_plot(**kwargs)
        if isinstance(self._data, DSC):
            self.mode = 'dsc'

    def _init_plot(self, **kwargs):
        """Build the plot item and error bars from style keyword arguments.

        Recognised keys: ``color``, ``symbolcolor``, ``linecolor``, ``symbol``,
        ``symbolsize`` (default 10), ``linestyle`` and ``linewidth`` (default 1).
        If neither symbol nor line style is given, data sets with more than 500
        points get a solid line without symbols. Smaller ones get the next
        symbol of ``PointContainer.symbols`` and no line.
        """
        self.plot_imag = None

        fallback = kwargs.get('color', None)
        sym_color = kwargs.get('symbolcolor', fallback)
        line_color = kwargs.get('linecolor', fallback)

        if sym_color is None and line_color is None:
            sym_color = line_color = next(self.colors)
        elif sym_color is None:
            sym_color = line_color
        elif line_color is None:
            line_color = sym_color

        symbol = kwargs.get('symbol', None)
        sym_size = kwargs.get('symbolsize', 10)
        line_style = kwargs.get('linestyle', None)
        line_width = kwargs.get('linewidth', 1)

        if symbol is None and line_style is None:
            # large data sets are easier to read as a line than as symbols
            if len(self._data) > 500:
                line_style, symbol = LineStyle.Solid, SymbolStyle.No
            else:
                line_style, symbol = LineStyle.No, next(PointContainer.symbols)

        self.plot_real = PlotItem(x=self.x, y=self.y, name=self.name,
                                  symbol=None, pen=None, connect='finite')
        self.setSymbol(mode='real', symbol=symbol, size=sym_size, color=sym_color)
        self.setLine(mode='real', style=line_style, width=line_width, color=line_color)

        # error bars take the symbol colour when symbols are drawn, else the line colour
        if symbol != SymbolStyle.No:
            err_color = self.plot_real.symbolcolor
        else:
            err_color = self.plot_real.linecolor

        self.plot_error = ErrorBars(x=self.x, y=self.y, top=self.y_err, bottom=self.y_err,
                                    pen=mkPen({'color': err_color.rgb()}))
class FitContainer(ExperimentContainer):
    """Container for a fit curve, drawn as line(s) without symbols."""

    def __init__(self, identifier, data, **kwargs):
        """
        Args:
            identifier: Key of this fit.
            data: Fit data object; must provide ``statistics``, ``nobs``, ``nvar``,
                ``parameter``, ``model_name`` and an ``idx`` attribute.
            **kwargs: ``src`` is the key of the fitted data set; ``color`` or
                ``linecolor`` set the line colour (default black).
        """
        super().__init__(identifier, data, **kwargs)
        self.fitted_key = kwargs.get('src', '')
        self.mode = 'fit'
        self.parent_set = kwargs.get('src', '')

        self._init_plot(**kwargs)

        # mirror the fit summary of the data object as plain attributes
        for n in ['statistics', 'nobs', 'nvar', 'parameter', 'model_name']:
            setattr(self, n, getattr(data, n))

    def _init_plot(self, **kwargs):
        """Create the real (and, for complex y, imaginary) line plot items."""
        color = kwargs.get('color')
        if color is None:
            color = kwargs.get('linecolor', (0, 0, 0))
        if isinstance(color, BaseColor):
            color = color.rgb()

        self.plot_real = PlotItem(x=self.x, y=self.y.real, name=self.name,
                                  pen=mkPen({'color': color}),
                                  connect='finite', symbol=None)

        if np.iscomplexobj(self._data.y):
            self.plot_imag = PlotItem(x=self.x, y=self.y.imag, name=self.name,
                                      pen=mkPen({'color': color}),
                                      connect='finite', symbol=None)

    @property
    def fitted_key(self):
        """Key of the data set this fit belongs to (stored on the data object)."""
        return self._data.idx

    @fitted_key.setter
    def fitted_key(self, val):
        self._data.idx = val

    def get_namespace(self, i: int = None, j: int = None):
        """Extend the base namespace with this fit's own parameters.

        Parameters are added as ``fit['name']``, prefixed with ``g[i].s[j].``
        unless both *i* and *j* are None (same rule as the base class; the
        previous unconditional ``%i`` formatting raised TypeError for None).
        """
        namespace = super().get_namespace(i, j)

        if (i is None) and (j is None):
            prefix = ''
        else:
            prefix = 'g[%i].s[%i].' % (i, j)

        namespace.update({
            "%sfit['%s']" % (prefix, convert(pname, old='latex', new='plain')): (pvalue.value, str(pvalue.value))
            for (pname, pvalue) in self._data.parameter.items()
        })

        return namespace
class SignalContainer(ExperimentContainer):
    """Container for FID, Spectrum and BDS data with real and imaginary plots.

    The imaginary part is drawn dashed whenever a line is shown.
    """

    symbols = symbolcycle()

    def __init__(self, identifier, data, symbol=None, **kwargs):
        super().__init__(identifier, data, **kwargs)
        self.mode = 'signal'
        self._init_plot(symbol=symbol, **kwargs)

    def _init_plot(self, **kwargs):
        """Create the real/imaginary plot items and apply the style keyword arguments.

        Without symbol and line style, BDS data with at most 91 points gets
        symbols and no line (larger BDS data a solid line); FID/Spectrum data
        gets a solid line.

        Raises:
            TypeError: if the data is neither BDS nor a Signal.
        """
        self.plot_real = PlotItem(x=self.x, y=self.y.real, name=self.name,
                                  symbol=None, pen=None, connect='finite')
        self.plot_imag = PlotItem(x=self.x, y=self.y.imag, name=self.name,
                                  symbol=None, pen=None, connect='finite')

        base_color = kwargs.get('color', None)
        sym_color = kwargs.get('symbolcolor', base_color)
        line_color = kwargs.get('linecolor', base_color)

        if sym_color is None and line_color is None:
            # base_color may still be set if both specific colours were passed as None
            if base_color is None:
                base_color = next(self.colors)
            sym_color = line_color = base_color
        elif sym_color is None:
            sym_color = line_color
        elif line_color is None:
            line_color = sym_color

        symbol = kwargs.get('symbol', None)
        sym_size = kwargs.get('symbolsize', 10)
        line_style = kwargs.get('linestyle', None)
        line_width = kwargs.get('linewidth', 1)
        style_missing = symbol is None and line_style is None

        data = self._data
        if isinstance(data, BDS):
            self.mode = 'bds'
            if style_missing:
                if len(data) > 91:
                    line_style, symbol = LineStyle.Solid, SymbolStyle.No
                else:
                    symbol, line_style = next(PointContainer.symbols), LineStyle.No

        elif isinstance(data, Signal):
            if style_missing:
                line_style, symbol = LineStyle.Solid, SymbolStyle.No
            self.mode = 'fid' if isinstance(data, FID) else 'spectrum'

        else:
            raise TypeError('Unknown class %s, should be FID, Spectrum, or BDS.' % type(data))

        self.setSymbol(mode='real', symbol=symbol, size=sym_size, color=sym_color)
        self.setLine(mode='real', style=line_style, width=line_width, color=line_color)

        # imaginary part: same symbols, but a dashed line unless lines are off
        if line_style != LineStyle.No:
            line_style = LineStyle.Dashed
        self.setSymbol(mode='imag', symbol=symbol, size=sym_size, color=sym_color)
        self.setLine(mode='imag', style=line_style, width=line_width, color=line_color)

    def _update_actions(self):
        """Add phase/baseline actions plus the type-specific FID/Spectrum actions."""
        super()._update_actions()

        data = self._data
        self.actions.update(ph=data.manual_phase, bls=data.baseline_spline, autoph=data.autophase)

        if isinstance(data, Spectrum):
            self.actions.update(bl=data.baseline, ls=data.shift,
                                divide=data.divide, ft=self.fourier)
            self.mode = 'spectrum'

        elif isinstance(data, FID):
            self.actions.update(bl=data.baseline, ls=data.shift,
                                zf=data.zerofill, divide=data.divide,
                                ap=data.apod, ft=self.fourier)
            self.mode = 'fid'

    @plot_update
    def fourier(self, mode='normal'):
        """Replace the data by its Fourier transform ('normal') or de-paked transform ('depake').

        If the data has no ``fft_depake`` (AttributeError), nothing changes and
        None is returned; otherwise the actions are rebuilt and ``self`` returned.
        """
        if mode == 'depake':
            try:
                transformed = self._data.fft_depake()
            except AttributeError:
                return
            self._data = transformed
        elif mode == 'normal':
            self._data = self._data.fourier()

        self._update_actions()

        return self