from __future__ import annotations from pyqtgraph import mkPen import numpy as np from numpy import pi from numpy.fft import fft, fftshift, fftfreq from nmreval.data import FID, Spectrum from ...lib.pg_objects import PlotItem, LogInfiniteLine from nmreval.lib.importer import find_models from nmreval.math import apodization as apodization from nmreval.utils.text import convert from ...Qt import QtCore, QtWidgets, QtGui from ..._py.apod_dialog import Ui_ApodEdit from ...lib.forms import FormWidget class QPreviewDialog(QtWidgets.QDialog, Ui_ApodEdit): finished = QtCore.pyqtSignal(str, tuple) def __init__(self, parent=None): super().__init__(parent=parent) self.setupUi(self) self.data = [] self.graphs = [] self._tmp_data_bl = [] self._tmp_data_zf = [] self._tmp_data_ls = [] self._tmp_data_ap = [] self._tmp_data_ph = [] self.pvt_line = LogInfiniteLine(pos=0, movable=True) self.freq_graph.addItem(self.pvt_line) self.pvt_line.sigPositionChanged.connect(self.move_line) self.ls_lineedit.hide() self.apods = find_models(apodization) self.apodcombobox.blockSignals(True) for ap in self.apods: self.apodcombobox.addItem(ap().name) self.apodcombobox.blockSignals(False) self.apod_graph = PlotItem(x=[], y=[]) self.time_graph.addItem(self.apod_graph) for g in [self.freq_graph, self.time_graph]: pl = g.getPlotItem() pl.hideButtons() pl.setMenuEnabled(False) self._all_time = None self._all_freq = None self.change_apodization(0) self.shift_box.clicked.connect(self._update_shift) self.ls_spinbox.valueChanged.connect(self._update_shift) self.ls_lineedit.setValidator(QtGui.QDoubleValidator()) self.ls_lineedit.textChanged.connect(self._update_shift) self.zerofill_box.clicked.connect(self._update_zf) self.zf_spinbox.valueChanged.connect(self._update_zf) self.apod_box.clicked.connect(self._update_apod) self.phase_box.clicked.connect(self._update_phase) self.ph0_spinbox.valueChanged.connect(self._update_phase) self.ph1_spinbox.valueChanged.connect(self._update_phase) self.pivot_lineedit.setValidator(QtGui.QDoubleValidator()) self.pivot_lineedit.textChanged.connect(self._update_phase) self.pivot_lineedit.textEdited.connect(lambda x: self.pvt_line.setValue(float(x))) def add_data(self: QPreviewDialog, data: FID | Spectrum) -> bool: if isinstance(data, FID): if self._all_freq: msg = QtWidgets.QMessageBox.warning(self, 'Mixed types', 'Timesignals and spectra cannot be edited at the same time.') return False else: self._all_time = True self._all_freq = False elif isinstance(data, Spectrum): if self._all_time: msg = QtWidgets.QMessageBox.warning(self, 'Mixed types', 'Timesignals and spectra cannot be edited at the same time.') return False else: self._all_time = False self._all_freq = True fid = data.copy() spec = self._temp_fft_time(fid.x, fid.y, self.baseline_box.isChecked()) x_len = data.x.size self.zf_spinbox.setMaximum(min(2**17//x_len, 3)) real_plt = PlotItem(x=fid.x, y=fid.y.real, pen=mkPen('b')) imag_plt = PlotItem(x=fid.x, y=fid.y.imag, pen=mkPen('r')) self.time_graph.addItem(imag_plt) self.time_graph.addItem(real_plt) real_plt_fft = PlotItem(x=spec[0], y=spec[1].real, pen=mkPen('b')) imag_plt_fft = PlotItem(x=spec[0], y=spec[1].imag, pen=mkPen('r')) self.freq_graph.addItem(imag_plt_fft) self.freq_graph.addItem(real_plt_fft) self.data.append(data) for p in [self._tmp_data_bl, self._tmp_data_ls]: p.append(data.y.copy()) for p in [self._tmp_data_zf, self._tmp_data_ap]: p.append((data.x, data.y.copy())) self._tmp_data_ph.append((data.x, data.y, spec[0], spec[1])) self.graphs.append((real_plt, imag_plt, real_plt_fft, imag_plt_fft)) return True @QtCore.pyqtSlot(name='on_baseline_box_clicked') def _update_bl(self): if self.baseline_box.isChecked(): for y in self._tmp_data_bl: self._temp_baseline(y) else: for i, d in enumerate(self.data): self._tmp_data_bl[i] = d.y.copy() self._update_shift() def _update_shift(self): if self.shift_box.isChecked(): if self.ls_combobox.currentIndex() == 0: num_points = self.ls_spinbox.value() is_time = False else: num_points = float(self.ls_lineedit.text()) is_time = True for i, y in enumerate(self._tmp_data_bl): self._tmp_data_ls[i] = self._temp_leftshift(self.data[i].dx, y, num_points, is_time) else: for i, y in enumerate(self._tmp_data_bl): self._tmp_data_ls[i] = y self._update_zf() def _update_zf(self): zf_padding = self.zf_spinbox.value() if self.zerofill_box.isChecked(): for i, y in enumerate(self._tmp_data_ls): self._tmp_data_zf[i] = self._temp_zerofill(self.data[i].x, y, zf_padding) else: for i, y in enumerate(self._tmp_data_ls): self._tmp_data_zf[i] = self.data[i].x, y self._update_apod() def _update_apod(self): if self.apod_box.isChecked(): model = self.apods[self.apodcombobox.currentIndex()] p = self._get_parameter() x_limit = np.inf, -np.inf y_limit = -np.inf for i, (x, y) in enumerate(self._tmp_data_zf): self._tmp_data_ap[i] = x, y * model.apod(x, *p) y_limit = max(y.real.max(), y_limit) x_limit = min(x_limit[0], x.min()), max(x_limit[1], x.max()) _x_apod = np.linspace(*x_limit, num=150) _y_apod = model.apod(_x_apod, *p) self.apod_graph.setData(x=_x_apod, y=y_limit * _y_apod) self.apod_graph.show() else: for i, (x, y) in enumerate(self._tmp_data_zf): self._tmp_data_ap[i] = x, y self.apod_graph.hide() self._update_phase() def _update_phase(self): if self.phase_box.isChecked(): pvt = float(self.pivot_lineedit.text()) self.pvt_line.show() ph0 = self.ph0_spinbox.value() ph1 = self.ph1_spinbox.value() for i, (x, y) in enumerate(self._tmp_data_ap): x_fft, y_fft = self._temp_fft_time(x, y, self.baseline_box.isChecked()) if ph0 != 0: y = self._temp_phase(x, y, ph0, 0, 0) y_fft = self._temp_phase(x, y_fft, ph0, ph1, pvt) elif ph1 != 0: y_fft = self._temp_phase(x, y_fft, ph0, ph1, pvt) self._tmp_data_ph[i] = x, y, x_fft, y_fft else: self.pvt_line.hide() for i, (x, y) in enumerate(self._tmp_data_ap): self._tmp_data_ph[i] = x, y, *self._temp_fft_time(x, y, self.baseline_box.isChecked()) self._update_plots() def _update_plots(self): for i, (x, y, xf, yf) in enumerate(self._tmp_data_ph): self.graphs[i][0].setData(x=x, y=y.real) self.graphs[i][1].setData(x=x, y=y.imag) self.graphs[i][2].setData(x=xf, y=yf.real) self.graphs[i][3].setData(x=xf, y=yf.imag) @staticmethod def _temp_baseline_time(y): y -= y[int(-0.12 * y.size):].mean() @staticmethod def _temp_baseline_freq(y): region = int(0.12 * y.size) y -= np.mean([y[-region:], y[:region]]) @staticmethod def _temp_phase(x: np.ndarray, y: np.ndarray, ph0: float, ph1: float, pvt: float) -> np.ndarray: phase_correction = np.exp(-1j * (ph0 + ph1 * (x - pvt) / x.max()) * pi / 180.) _y = y * phase_correction return _y @staticmethod def _temp_zerofill(x: np.ndarray, y: np.ndarray, num_padding: int) -> tuple[np.ndarray, np.ndarray]: length = x.size factor = 2**num_padding _y = np.r_[y, np.zeros((factor-1) * length)] _temp_x = np.arange(1, (factor-1) * length+1) * (x[1]-x[0]) + np.max(x) _x = np.r_[x, _temp_x] return _x, _y @staticmethod def _temp_leftshift(dx: np.ndarray, y: np.ndarray, points: float | int, is_time: bool) -> np.ndarray: if is_time: points = int(points//dx) _y = np.roll(y, -points) _y[-points-1:] = 0 return _y @staticmethod def _temp_fft_time(x: np.ndarray, y: np.ndarray, baseline: bool = False) -> tuple[np.ndarray, np.ndarray]: y_fft = fftshift(fft(y)) x_fft = fftshift(fftfreq(len(x), d=x[1]-x[0])) if baseline: QPreviewDialog._temp_baseline_freq(y_fft) return x_fft, y_fft @staticmethod def _temp_fft_freq(x: np.ndarray, y: np.ndarray, _=None): return x, y def move_line(self, evt): self.pivot_lineedit.setText(f'{evt.value():.5g}') @QtCore.pyqtSlot(int, name='on_apodcombobox_currentIndexChanged') def change_apodization(self, index: int) -> None: # delete old widgets self.eqn_label.setText(convert(self.apods[index].equation)) while self.widget_layout.count(): item = self.widget_layout.takeAt(0) if isinstance(item, FormWidget): item.disconnect() try: item.widget().deleteLater() except AttributeError: pass # set up parameter widgets for new model for k, v in enumerate(self.apods[index].params): widget = FormWidget(name=v) widget.value = 1 widget.valueChanged.connect(self._update_apod) self.widget_layout.addWidget(widget) self.widget_layout.addStretch() self._update_apod() def _get_parameter(self): p = [] for i in range(self.widget_layout.count()): item = self.widget_layout.itemAt(i) w = item.widget() try: p.append(w.value) except AttributeError: continue return p @QtCore.pyqtSlot(int, name='on_ls_combobox_currentIndexChanged') def change_ls(self, idx: int) -> None: self.ls_lineedit.setVisible(bool(idx)) self.ls_spinbox.setVisible(not bool(idx)) @QtCore.pyqtSlot(bool, name='on_ft_checkbox_stateChanged') def change_ft(self, state: bool): self.ph1_spinbox.setEnabled(state) self.pivot_lineedit.setEnabled(state) def cleanup(self): self.blockSignals(True) for line in self.graphs: for g in line: self.time_graph.removeItem(g) self.freq_graph.removeItem(g) del g self.time_graph.clear() self.freq_graph.clear() self._tmp_data_ap = [] self._tmp_data_bl = [] self._tmp_data_ls = [] self._tmp_data_ph = [] self._tmp_data_zf = [] self.data = [] self.graphs = [] self.freq_graph.removeItem(self.pvt_line) self.time_graph.removeItem(self.pvt_line) self.blockSignals(False) def get_value(self): edits = [(None,), (None,), (None,), (None,), (None,), (None,)] if self.baseline_box.isChecked(): edits[0] = (True,) if self.zerofill_box.isChecked(): edits[2] = (self.zf_spinbox.value(),) if self.shift_box.isChecked(): if self.ls_combobox.currentIndex() == 0: edits[1] = (self.ls_spinbox.value(), 'pts') else: edits[1] = (float(self.ls_lineedit.text()), 'time') if self.apod_box.isChecked(): edits[3] = (self._get_parameter(), self.apods[self.apodcombobox.currentIndex()]) if self.phase_box.isChecked(): edits[4] = (self.ph0_spinbox.value(), self.ph1_spinbox.value(), float(self.pivot_lineedit.text())) if self.ft_box.isChecked(): edits[5] = (self.phase_before_button.isChecked(),) return edits def exec(self): self._prepare_ui() return super().exec() def _prepare_ui(self): """Stuff we have to do before showing the window but after all the data was added""" vb = self.freq_graph.getPlotItem().getViewBox() vb.disableAutoRange(axis=vb.YAxis) vb = self.time_graph.getPlotItem().getViewBox() vb.disableAutoRange(axis=vb.YAxis) self.zerofill_box.setVisible(self._all_time) self.apod_box.setVisible(self._all_time) self.shift_box.setVisible(self._all_time) self.time_graph.setVisible(self._all_time) self.logtime_widget.setVisible(self._all_time) @QtCore.pyqtSlot(int, name='on_logx_time_stateChanged') @QtCore.pyqtSlot(int, name='on_logy_time_stateChanged') @QtCore.pyqtSlot(int, name='on_logx_freq_stateChanged') @QtCore.pyqtSlot(int, name='on_logy_freq_stateChanged') def set_log(self, state: int): switch = { self.logx_time: lambda _x: self.time_graph.setLogMode(x=_x), self.logy_time: lambda _x: self.time_graph.setLogMode(y=_x), self.logx_freq: lambda _x: self.freq_graph.setLogMode(x=_x), self.logy_freq: lambda _x: self.freq_graph.setLogMode(y=_x), }[self.sender()] switch(state == QtCore.Qt.Checked) vb = self.freq_graph.getPlotItem().getViewBox() vb.disableAutoRange(axis=vb.YAxis) vb = self.time_graph.getPlotItem().getViewBox() vb.disableAutoRange(axis=vb.YAxis) self._temp_baseline = self._temp_baseline_time if self._all_time else self._temp_baseline_freq self._temp_fft = self._temp_fft_time if self._all_time else self._temp_fft_freq self.freq_graph.setVisible(self._all_time) if self._all_freq: self.time_graph.addItem(self.pvt_line) else: self.freq_graph.addItem(self.pvt_line)