from __future__ import annotations

import itertools
import os
import uuid

from collections.abc import Iterator
from math import isfinite, isnan
from pathlib import Path

from numpy import errstate, floor, log10
from pyqtgraph import GraphicsObject, getConfigOption, mkColor

from nmreval.utils.text import convert
from ..io.filedialog import FileDialog

from ..lib.pg_objects import LegendItemBlock, RegionItem
from ..Qt import QtCore, QtWidgets, QtGui
from .._py.graph import Ui_GraphWindow
from ..lib import make_action_icons
from ..lib.configurations import GraceMsgBox


class QGraphWindow(QtWidgets.QGraphicsView, Ui_GraphWindow):
    """Window showing several data sets in one pyqtgraph plot.

    Each data set is identified by an id (``self.sets`` keeps all ids in stacking
    order) and consists of a real-part plot, an optional imaginary-part plot and
    an optional error-bar item. ``self.active`` holds the ids of the sets that
    are currently shown.
    """

    mousePositionChanged = QtCore.pyqtSignal(float, float)
    mouseDoubleClicked = QtCore.pyqtSignal()
    positionClicked = QtCore.pyqtSignal(tuple, bool)
    aboutToClose = QtCore.pyqtSignal(str)

    # running number used for the default window titles ("Graph 0", "Graph 1", ...)
    counter = itertools.count()

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self.setupUi(self)

        self._bgcolor = mkColor(getConfigOption('background'))
        self._fgcolor = mkColor(getConfigOption('foreground'))
        # (foreground, background) swapped in by change_background
        self._prev_colors = mkColor('k'), mkColor('w')

        self._init_gui()

        make_action_icons(self)

        self.id = str(uuid.uuid4())

        # ids of all sets in stacking order / ids of the sets currently shown
        self.sets = []
        self.active = []

        # set id -> plot item (None if the set has no such part)
        self.real_plots = {}
        self.imag_plots = {}
        self.error_plots = {}

        # external items without ``setLogMode``; set_logmode updates their ``logmode``
        self._special_needs = []
        self._external_items = []
        self.closable = True

        # log-scale state of [x axis, y axis]
        self.log = [False, False]

        # NOTE: this attribute shadows QGraphicsView.scene(); kept for existing callers
        self.scene = self.plotItem.scene()
        self.scene.sigMouseMoved.connect(self.move_mouse)

        self.checkBox.stateChanged.connect(self._legend_checkbox_changed)
        self.label_button.toggled.connect(lambda x: self.label_widget.setVisible(x))
        self.limit_button.toggled.connect(lambda x: self.limit_widget.setVisible(x))
        self.gridbutton.toggled.connect(lambda x: self.graphic.showGrid(x=x, y=x))
        self.logx_button.toggled.connect(lambda x: self.set_logmode(xmode=x))
        self.logy_button.toggled.connect(lambda x: self.set_logmode(ymode=x))
        self.graphic.plotItem.vb.sigRangeChanged.connect(self.update_limits)
        self.listWidget.itemChanged.connect(self.show_legend)

        # reconnect "Export..." in context menu to our function
        self.scene.contextMenu[0].disconnect()
        self.scene.contextMenu[0].triggered.connect(self.export_dialog)

    def _init_gui(self):
        """Set the window title, axes, legend and input validators."""
        self.setWindowTitle('Graph ' + str(next(QGraphWindow.counter)))

        self.label_widget.hide()
        self.limit_widget.hide()
        self.listWidget.hide()
        self.checkBox.hide()

        self.plotItem = self.graphic.plotItem
        for orient in ['top', 'bottom', 'left', 'right']:
            self.plotItem.showAxis(orient)
            ax = self.plotItem.getAxis(orient)
            ax.enableAutoSIPrefix(False)
            # top and right axes only frame the plot, they carry no tick labels
            if orient == 'top':
                ax.setStyle(showValues=False)
                ax.setHeight(10)
            elif orient == 'right':
                ax.setStyle(showValues=False)
                ax.setWidth(10)

        self.legend = LegendItemBlock(offset=(20, 20))
        self.legend.setParentItem(self.plotItem.vb)
        self.plotItem.legend = self.legend
        self.legend.setVisible(True)

        self.plotItem.setMenuEnabled(False, True)
        # log mode is driven by set_logmode; keep the hidden checkboxes silent
        self.plotItem.ctrl.logXCheck.blockSignals(True)
        self.plotItem.ctrl.logYCheck.blockSignals(True)

        for lineedit in [self.xmin_lineedit, self.xmax_lineedit, self.ymin_lineedit, self.ymax_lineedit]:
            lineedit.setValidator(QtGui.QDoubleValidator())

    def __contains__(self, item: str):
        return item in self.sets

    def __iter__(self):
        return iter(self.active)

    def __len__(self):
        return len(self.active)

    def _legend_checkbox_changed(self, state):
        """Show/hide the legend according to the checkbox state.

        ``show_legend`` does nothing while the legend is hidden, so the legend is
        rebuilt when it becomes visible again to reflect changes made meanwhile.
        """
        visible = state == QtCore.Qt.Checked
        self.legend.setVisible(visible)
        if visible:
            self.show_legend()

    @staticmethod
    def _from_log(value: float) -> float:
        """Return ``10**value``; fall back to ``value`` itself if that overflows."""
        try:
            return 10**value
        except OverflowError:
            return value

    def curves(self) -> Iterator[tuple]:
        """Yield the plot items of all active sets, in stacking order.

        If the real part is shown, yields ``(real_plot, error_plot)`` when the set
        has error bars, else ``(real_plot,)``. If the imaginary part is shown and
        exists, additionally yields ``(imag_plot,)``.
        """
        for set_id in self.sets:
            if set_id not in self.active:
                continue

            if self.real_button.isChecked():
                if self.error_plots[set_id] is not None:
                    yield self.real_plots[set_id], self.error_plots[set_id]
                else:
                    yield self.real_plots[set_id],

            if self.imag_button.isChecked() and self.imag_plots[set_id] is not None:
                yield self.imag_plots[set_id],

    @property
    def title(self):
        return self.windowTitle()

    @title.setter
    def title(self, value):
        self.setWindowTitle(str(value))

    @property
    def ranges(self) -> tuple:
        """Visible ``((xmin, xmax), (ymin, ymax))`` in data coordinates.

        In log mode the view box works in log10 units, so those are converted back.
        """
        view_range = self.plotItem.getViewBox().viewRange()
        return tuple(
            tuple(10**v for v in axis_range) if is_log else tuple(axis_range)
            for is_log, axis_range in zip(self.log, view_range)
        )

    def add(self, name: str | list, plots: list):
        """Add one or several sets and show them.

        Args:
            name: a set id or a list of ids.
            plots: for a single id one ``(real, imag, error)`` triple, otherwise a
                list of such triples matching ``name``. Missing parts are None.
        """
        if isinstance(name, str):
            name = [name]
            plots = [plots]

        for (real_plot, imag_plot, err_plot), n in zip(plots, name):
            toplevel = len(self.sets)
            self.sets.append(n)

            # new sets go on top; error bars are drawn just below their points
            if real_plot:
                real_plot.setZValue(2*toplevel+1)
            if imag_plot:
                imag_plot.setZValue(2*toplevel+1)
            if err_plot:
                err_plot.setZValue(2*toplevel)

            self.real_plots[n] = real_plot
            self.imag_plots[n] = imag_plot
            self.error_plots[n] = err_plot

            list_item = QtWidgets.QListWidgetItem(real_plot.opts.get('name', ''))
            list_item.setData(QtCore.Qt.UserRole, n)
            list_item.setFlags(QtCore.Qt.ItemIsEnabled | QtCore.Qt.ItemIsSelectable | QtCore.Qt.ItemIsUserCheckable)
            list_item.setCheckState(QtCore.Qt.Checked)
            self.listWidget.addItem(list_item)

        self.show_item(name)

    def remove(self, name: str | list):
        """Remove one or several sets from the graph and from the legend list."""
        if isinstance(name, str):
            name = [name]

        for n in name:
            self.sets.remove(n)

            for plot_dic in (self.real_plots, self.imag_plots, self.error_plots):
                item = plot_dic[n]
                if item is not None:
                    self.graphic.removeItem(item)

            if n in self.active:
                self.active.remove(n)

        # remove from label list; walk backwards (down to index 0) so that
        # taking an item does not shift the indices still to be visited
        self.listWidget.blockSignals(True)

        for i in reversed(range(self.listWidget.count())):
            if self.listWidget.item(i).data(QtCore.Qt.UserRole) in name:
                self.listWidget.takeItem(i)

        self.listWidget.blockSignals(False)

        self._update_zorder()
        self.show_legend()

    def move_sets(self, sets: list, position: int):
        """Move ``sets`` to ``position`` in the stacking order (-1 appends at the end).

        ``position`` refers to the order after the moved sets have been taken out.
        The legend list is reordered in the same way.
        """
        move_plots = []
        move_items = []

        self.listWidget.blockSignals(True)

        for s in sets:
            idx = self.sets.index(s)
            move_plots.append(self.sets.pop(idx))
            move_items.append(self.listWidget.takeItem(idx))

        if position == -1:
            self.sets.extend(move_plots)
            for it in move_items:
                self.listWidget.addItem(it)
        else:
            self.sets = self.sets[:position] + move_plots + self.sets[position:]
            # insert in reverse so the moved items keep their relative order
            for it in move_items[::-1]:
                self.listWidget.insertItem(position, it)

        self.listWidget.blockSignals(False)
        self._update_zorder()

    def show_item(self, idlist: list):
        """Mark the sets in ``idlist`` active and add their enabled parts to the plot."""
        if len(self.sets) == 0:
            return

        for a in idlist:
            if a not in self.active:
                self.active.append(a)

            for (bttn, plot_dic) in [
                (self.real_button, self.real_plots),
                (self.imag_button, self.imag_plots),
                (self.error_button, self.error_plots),
            ]:
                if bttn.isChecked():
                    item = plot_dic[a]
                    if (item is not None) and (item not in self.graphic.items()):
                        self.graphic.addItem(item)

        self.show_legend()

    def hide_item(self, idlist: list):
        """Deactivate the sets in ``idlist`` and remove all their parts from the plot."""
        if len(self.sets) == 0:
            return

        for r in idlist:
            if r in self.active:
                self.active.remove(r)

            for plt in [self.real_plots, self.imag_plots, self.error_plots]:
                item = plt[r]
                if item in self.graphic.items():
                    self.graphic.removeItem(item)

        # drop legend entries of the now hidden sets
        self.show_legend()

    @QtCore.pyqtSlot(bool, name='on_imag_button_toggled')
    @QtCore.pyqtSlot(bool, name='on_real_button_toggled')
    def set_imag_visible(self, visible: bool):
        """Show/hide the real or imaginary parts (depending on the sender) of active sets."""
        if self.sender() == self.real_button:
            plots = self.real_plots
            # error bars are not shown without their points
            if self.error_button.isChecked() and not visible:
                self.error_button.setChecked(False)
        else:
            plots = self.imag_plots

        if visible:
            func = self.graphic.addItem
        else:
            func = self.graphic.removeItem

        for a in self.active:
            item = plots[a]
            if item is not None:
                func(item)

        self.show_legend()

    @QtCore.pyqtSlot(bool, name='on_error_button_toggled')
    def show_errorbar(self, visible: bool):
        """Show/hide the error bars of active sets."""
        if visible and not self.real_button.isChecked():
            # no errorbars without points
            self.error_button.blockSignals(True)
            self.error_button.setChecked(False)
            self.error_button.blockSignals(False)
            return

        if visible:
            for a in self.active:
                item = self.error_plots[a]
                if (item is not None) and (item not in self.graphic.items()):
                    self.graphic.addItem(item)
        else:
            for a in self.active:
                item = self.error_plots[a]
                if (item is not None) and (item in self.graphic.items()):
                    self.graphic.removeItem(item)

    def add_external(self, item):
        """Add an item that is not a data set (e.g. a region) on top of all sets.

        Returns:
            False if the item is already part of the plot, True otherwise.
        """
        if isinstance(item, RegionItem) and item.first:
            # Give regions nice values on first addition to a graph
            x, _ = self.ranges

            if item.mode == 'mid':
                onset = item.getRegion()[0]
                # region spans 1/10 of the visible x range around its onset
                if self.log[0]:
                    delta = log10(x[1]/x[0])/20
                    span = (onset / 10**delta, onset * 10**delta)
                else:
                    delta = x[1]-x[0]
                    span = (onset-delta/20, onset + delta/20)
            elif item.mode == 'half':
                # central half of the visible x range
                span = (0.75*x[0]+0.25*x[1], 0.25*x[0]+0.75*x[1])
            else:
                span = item.getRegion()

            item.setRegion(span)
            item.first = False

        if item in self.graphic.items():
            return False

        if not hasattr(item, 'setLogMode'):
            self._special_needs.append(item)

        self._external_items.append(item)
        self.graphic.addItem(item)
        item.setZValue(1000)

        return True

    @QtCore.pyqtSlot(GraphicsObject)
    def remove_external(self, item):
        """Forget an external item and remove it from the plot.

        Returns:
            False if the item was not part of the plot, True otherwise.
        """
        if item in self._external_items:
            self._external_items.remove(item)

        if item in self._special_needs:
            self._special_needs.remove(item)

        if item not in self.graphic.items():
            return False

        self.graphic.removeItem(item)

        return True

    def closeEvent(self, evt: QtGui.QCloseEvent):
        """Ask before closing a non-empty graph; emit ``aboutToClose`` when closing."""
        if not self.closable:
            evt.ignore()
            return

        res = QtWidgets.QMessageBox.Yes
        if len(self.sets) != 0:
            # offer both answers; defaulting to "No" protects against accidental data loss
            res = QtWidgets.QMessageBox.question(self, 'Plot not empty', 'Graph is not empty. Deleting with all data?',
                                                 QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
                                                 QtWidgets.QMessageBox.No)

        if res == QtWidgets.QMessageBox.Yes:
            self.aboutToClose.emit(self.id)
            evt.accept()
        else:
            evt.ignore()

    def move_mouse(self, evt):
        """Emit the data coordinates of the mouse while it is over the plot area."""
        vb = self.plotItem.getViewBox()
        if self.plotItem.sceneBoundingRect().contains(evt):
            pos = vb.mapSceneToView(evt)
            _x = self._from_log(pos.x()) if self.log[0] else pos.x()
            _y = self._from_log(pos.y()) if self.log[1] else pos.y()
            self.mousePositionChanged.emit(_x, _y)

    @QtCore.pyqtSlot(name='on_title_lineedit_returnPressed')
    @QtCore.pyqtSlot(name='on_xaxis_linedit_returnPressed')
    @QtCore.pyqtSlot(name='on_yaxis_linedit_returnPressed')
    def labels_changed(self):
        """Apply the text of the line edit that sent the signal as title or axis label."""
        label = {self.title_lineedit: 'title', self.xaxis_linedit: 'x', self.yaxis_linedit: 'y'}[self.sender()]
        self.set_label(**{label: self.sender().text()})

    def set_label(self, x=None, y=None, title=None):
        """Set axis labels and/or title; texts are converted from TeX to HTML."""
        if title is not None:
            self.plotItem.setTitle(convert(title, old='tex', new='html'), **{'size': '10pt', 'color': self._fgcolor})

        if x is not None:
            self.plotItem.setLabel('bottom', convert(x, old='tex', new='html'),
                                   **{'font-size': '10pt', 'color': self._fgcolor.name()})

        if y is not None:
            self.plotItem.setLabel('left', convert(y, old='tex', new='html'),
                                   **{'font-size': '10pt', 'color': self._fgcolor.name()})

    def set_logmode(self, xmode: bool = None, ymode: bool = None):
        """Switch the x and/or y axis between linear and log scale (None keeps the current mode)."""
        if xmode is None:
            xmode = self.plotItem.ctrl.logXCheck.isChecked()
        else:
            self.plotItem.ctrl.logXCheck.setChecked(bool(xmode))

        if ymode is None:
            ymode = self.plotItem.ctrl.logYCheck.isChecked()
        else:
            self.plotItem.ctrl.logYCheck.setChecked(bool(ymode))

        self.log = [xmode, ymode]

        for item in self._special_needs:
            item.logmode[0] = self.log[:]

        # pyqtgraph reads the (hidden) log checkboxes here
        self.plotItem.updateLogMode()

        self.plotItem.enableAutoRange()

    def enable_picking(self, enabled: bool):
        """Connect (or disconnect) mouse clicks in the scene to ``position_picked``."""
        # Only our own slot is disconnected; dropping it first also prevents a
        # duplicate connection when this is called with enabled=True repeatedly.
        try:
            self.scene.sigMouseClicked.disconnect(self.position_picked)
        except (TypeError, RuntimeError):
            # slot was not connected
            pass

        if enabled:
            self.scene.sigMouseClicked.connect(self.position_picked)

    def position_picked(self, evt):
        """Emit ``positionClicked`` with the data coordinates of a left click in the plot."""
        vb = self.graphic.plotItem.vb

        if (self.graphic.plotItem.sceneBoundingRect().contains(evt.scenePos())
                and evt.button() == QtCore.Qt.LeftButton):
            pos = vb.mapSceneToView(evt.scenePos())
            _x, _y = pos.x(), pos.y()

            if self.log[0]:
                _x = self._from_log(_x)

            if self.log[1]:
                _y = self._from_log(_y)

            self.positionClicked.emit((_x, _y), evt.double())

    @QtCore.pyqtSlot(name='on_apply_button_clicked')
    def set_range(self, x: tuple = None, y: tuple = None):
        """Set the visible x/y range in data coordinates.

        Ranges that are not given are read from the limit line edits; if their text
        cannot be parsed as numbers, nothing is changed. In log mode, non-positive
        limits are replaced by a usable fallback.
        """
        try:
            if x is None:
                x = float(self.xmin_lineedit.text()), float(self.xmax_lineedit.text())
                x = min(x), max(x)

            if y is None:
                y = float(self.ymin_lineedit.text()), float(self.ymax_lineedit.text())
                y = min(y), max(y)
        except ValueError:
            # incomplete or invalid input in the line edits
            return

        for log, xy, func in zip(self.log, (x, y), (self.graphic.setXRange, self.graphic.setYRange)):
            if log:
                with errstate(all='ignore'):
                    xy = [log10(val) for val in xy]

                # log10 gives nan for negative values and -inf for zero
                if not isfinite(xy[1]):
                    xy = [-1, 1]
                elif not isfinite(xy[0]):
                    xy[0] = xy[1]-4

            func(xy[0], xy[1], padding=0)

    @QtCore.pyqtSlot(object)
    def update_limits(self, _):
        """Write the current view range into the limit line edits."""
        (xmin, xmax), (ymin, ymax) = self.ranges
        self.xmin_lineedit.setText(f'{xmin:.5g}')
        self.xmax_lineedit.setText(f'{xmax:.5g}')

        self.ymin_lineedit.setText(f'{ymin:.5g}')
        self.ymax_lineedit.setText(f'{ymax:.5g}')

    def _update_zorder(self):
        """Reassign z values so the stacking follows the order of ``self.sets``."""
        for i, sid in enumerate(self.sets):
            plt = self.real_plots[sid]
            if plt.zValue() != 2*i+1:
                plt.setZValue(2*i+1)
                if self.imag_plots[sid] is not None:
                    self.imag_plots[sid].setZValue(2*i+1)
                if self.error_plots[sid] is not None:
                    self.error_plots[sid].setZValue(2*i)

        self.show_legend()

    @QtCore.pyqtSlot(bool, name='on_legend_button_toggled')
    def show_legend_item_list(self, visible: bool):
        self.listWidget.setVisible(visible)
        self.checkBox.setVisible(visible)

    def update_legend(self, sid, name):
        """Set the legend-list text of set ``sid`` and rebuild the legend."""
        self.listWidget.blockSignals(True)

        for i in range(self.listWidget.count()):
            item = self.listWidget.item(i)
            if item.data(QtCore.Qt.UserRole) == sid:
                item.setText(convert(name, old='tex', new='html'))

        self.listWidget.blockSignals(False)
        self.show_legend()

    def show_legend(self):
        """Rebuild the legend from the sets whose real or imaginary part is shown.

        Sets unchecked in the legend list are skipped. Does nothing while the
        legend is hidden.
        """
        if not self.legend.isVisible():
            return

        self.legend.clear()

        # "show in legend" flag per set id, taken from the checkable list widget
        in_legend = {}
        for i in range(self.listWidget.count()):
            list_item = self.listWidget.item(i)
            in_legend[list_item.data(QtCore.Qt.UserRole)] = list_item.checkState() != QtCore.Qt.Unchecked

        shown = self.graphic.items()

        for sid in self.sets:
            if not in_legend.get(sid, True):
                continue

            # prefer the real part; fall back to the imaginary part if only that is shown
            item = self.real_plots[sid]
            other_item = self.imag_plots[sid]
            if item in shown:
                self.legend.addItem(item, convert(item.opts.get('name', ''), old='tex', new='html'))
            elif other_item in shown:
                self.legend.addItem(other_item, convert(other_item.opts.get('name', ''), old='tex', new='html'))

    def export_dialog(self, path=None):
        """Ask for a file name and export the graph to it."""
        filters = 'All files (*.*);;AGR (*.agr);;SVG (*.svg);;PDF (*.pdf)'
        for imgformat in QtGui.QImageWriter.supportedImageFormats():
            str_format = imgformat.data().decode('utf-8')
            filters += ';;' + str_format.upper() + ' (*.' + str_format + ')'

        # As a context-menu slot this receives QAction.triggered's ``checked``
        # flag; treat it like "no start directory given".
        if path is None or isinstance(path, bool):
            path = ''
        outfile = None
        f = FileDialog(caption='Export graphic', directory=str(path), filter=filters, mode='save')
        f.setOption(FileDialog.DontConfirmOverwrite)
        mode = f.exec()
        if mode == QtWidgets.QDialog.Accepted:
            outfile = f.save_file()
        if outfile:
            self.export(outfile)

    def export(self, outfile: Path):
        """Export the graph; the format is chosen by the file suffix.

        ``.agr`` is written with the Grace exporter (optionally appending to an
        existing file), ``.pdf`` and ``.svg`` with their exporters and anything
        else as an image. Non-Grace exports are drawn black on white.
        """
        outfile = Path(outfile)
        suffix = outfile.suffix.lower()
        if suffix == '':
            QtWidgets.QMessageBox.warning(self, 'No file extension',
                                          'No file extension found, graphic was not saved.')
            return

        if suffix == '.agr':
            res = 0
            if outfile.exists():
                res = GraceMsgBox(outfile, parent=self).exec()
                if res == -1:
                    return

            opts = self.export_graphics()

            from ..io.exporters import GraceExporter
            # 0: overwrite, 1: append, >=2: mode index res-2 passed on to GraceExporter
            if res == 0:
                mode = 'w'
            elif res == 1:
                mode = 'a'
            else:
                mode = res-2

            GraceExporter(opts).export(outfile, mode=mode)

        else:
            if outfile.exists():
                if QtWidgets.QMessageBox.warning(self, 'Export graphic',
                                                 f'{outfile.name} already exists.\n'
                                                 f'Do you REALLY want to replace it?',
                                                 QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
                                                 QtWidgets.QMessageBox.No) == QtWidgets.QMessageBox.No:
                    return

            bg_color = self._bgcolor
            fg_color = self._fgcolor
            self.set_color(foreground='k', background='w')

            try:
                if suffix == '.pdf':
                    from ..io.exporters import PDFPrintExporter
                    PDFPrintExporter(self.graphic).export(str(outfile))

                elif suffix == '.svg':
                    from pyqtgraph.exporters import SVGExporter
                    SVGExporter(self.scene).export(str(outfile))

                else:
                    from pyqtgraph.exporters import ImageExporter

                    ImageExporter(self.scene).export(str(outfile))
            finally:
                # restore the on-screen colours even if the export failed
                self.set_color(foreground=fg_color, background=bg_color)

    def export_graphics(self) -> dict:
        """Return ``get_state()`` extended by the data of all shown curves and external items.

        ``dic['items']`` holds one dict per exported item and ``dic['in_legend']``
        the matching "has a legend entry" flags, index for index.
        """
        dic = self.get_state()
        dic['items'] = []

        in_legend = []

        for item in self.curves():
            plot_item = item[0]
            legend_shown = any(sample.item is plot_item for sample, _ in self.legend.items)

            try:
                item_dic = plot_item.get_data_opts()
            except Exception as e:
                print(f'{item} could not be exported because {e.args}')
                continue

            if len(item) == 2:
                # plot can show errorbars
                item_dic['yerr'] = item[1].opts['topData']

            if item_dic:
                dic['items'].append(item_dic)
                # only record the flag for items that are actually exported
                in_legend.append(legend_shown)

        for item in self._external_items:
            try:
                dic['items'].append(item.get_data_opts())
            except Exception as e:
                print(f'{item} could not be exported because {e.args}')
                continue

            in_legend.append(False)

        dic['in_legend'] = in_legend

        return dic

    @staticmethod
    def _tick_spacing(vmin: float, vmax: float) -> float:
        """Major tick distance giving at most 10 ticks between ``vmin`` and ``vmax``.

        Falls back to 1 for a zero-width or non-finite range.
        """
        dist = vmax - vmin
        if dist == 0 or not isfinite(dist):
            return 1.

        scale = 10**floor(log10(abs(dist)))
        for step in (0.1, 0.2, 0.25, 0.5, 1., 2., 2.5, 5., 10., 20., 50., 100.):
            if dist / step / scale <= 10:
                break
        return step * scale

    def get_state(self) -> dict:
        """Collect the window state (limits, ticks, labels, legend, sets) in a dict.

        Lists are copied, so the returned dict does not alias the window's state.
        """
        limits = self.ranges
        dic = {
            'id': self.id,
            'limits': (limits[0], limits[1]),
            'ticks': (),
            'labels': (self.plotItem.getAxis('bottom').labelText,
                       self.plotItem.getAxis('left').labelText,
                       self.plotItem.titleLabel.text,
                       self.title),
            'log': list(self.log),
            'grid': self.gridbutton.isChecked(),
            'legend': self.legend.isVisible(),
            'plots': (self.real_button.isChecked(), self.imag_button.isChecked(), self.error_button.isChecked()),
            'children': list(self.sets),
            'active': list(self.active),
        }

        dic['in_legend'] = [
            self.listWidget.item(i).checkState() != QtCore.Qt.Unchecked
            for i in range(self.listWidget.count())
        ]

        # bottomLeft gives top left corner
        l_topleft = self.plotItem.vb.itemBoundingRect(self.legend).bottomLeft()
        legend_origin = [l_topleft.x(), l_topleft.y()]
        for i in [0, 1]:
            if self.log[i]:
                legend_origin[i] = 10**legend_origin[i]
        dic['legend_pos'] = legend_origin

        # (major, minor) ticks for x and y axis
        for i in (0, 1):
            if self.log[i]:
                major = 10
                minor = 9
            else:
                vmin, vmax = limits[i][0], limits[i][1]
                major = self._tick_spacing(vmin, vmax)
                minor = 1

            dic['ticks'] += (major, minor),

        return dic

    @staticmethod
    def set_state(state):
        """Create a new window from a dict produced by ``get_state``."""
        graph = QGraphWindow()
        graph.id = state.get('id', graph.id)

        graph.plotItem.setLabel('bottom', state['labels'][0], **{'font-size': '10pt', 'color': graph._fgcolor.name()})
        graph.plotItem.setLabel('left', state['labels'][1], **{'font-size': '10pt', 'color': graph._fgcolor.name()})
        graph.plotItem.setTitle(state['labels'][2], **{'size': '10pt', 'color': graph._fgcolor.name()})
        graph.setWindowTitle(state['labels'][3])

        graph.graphic.showGrid(x=state['grid'], y=state['grid'])

        graph.checkBox.setCheckState(QtCore.Qt.Checked if state['legend'] else QtCore.Qt.Unchecked)

        graph.real_button.setChecked(state['plots'][0])
        graph.imag_button.setChecked(state['plots'][1])
        graph.error_button.setChecked(state['plots'][2])

        graph.set_range(x=state['limits'][0], y=state['limits'][1])
        graph.logx_button.setChecked(state['log'][0])
        graph.logy_button.setChecked(state['log'][1])

        return graph

    def set_color(self, foreground=None, background=None):
        """Change foreground and/or background colour (anything ``mkColor`` accepts).

        ``None`` keeps the current colour. Axes, legend text, title and axis labels
        are always restyled with the current foreground colour.
        """
        if background is not None:
            self._bgcolor = mkColor(background)
            self.graphic.setBackground(self._bgcolor)
            self.legend.setBrush(self._bgcolor)

        if foreground is not None:
            self._fgcolor = mkColor(foreground)

        # all four axes are shown (see _init_gui), so all of them are recoloured
        for orient in ('left', 'bottom', 'top', 'right'):
            axis = self.plotItem.getAxis(orient)
            pen = axis.pen()
            pen.setColor(self._fgcolor)

            axis.setPen(pen)
            axis.setTextPen(pen)

        self.legend.setLabelTextColor(self._fgcolor)
        if self.legend.isVisible():
            self.show_legend()

        # Only restyle texts that exist: setting an empty title/label would
        # otherwise reserve empty space for it (e.g. in every export).
        title = self.plotItem.titleLabel.text
        if title:
            self.plotItem.setTitle(title, **{'size': '10pt', 'color': self._fgcolor})

        x = self.plotItem.getAxis('bottom').labelText
        if x:
            self.plotItem.setLabel('bottom', x, **{'font-size': '10pt', 'color': self._fgcolor.name()})

        y = self.plotItem.getAxis('left').labelText
        if y:
            self.plotItem.setLabel('left', y, **{'font-size': '10pt', 'color': self._fgcolor.name()})

    @QtCore.pyqtSlot(bool, name='on_bwbutton_toggled')
    def change_background(self, _):
        """Swap the current colours with the previously used ones."""
        temp = self._fgcolor, self._bgcolor
        self.set_color(foreground=self._prev_colors[0], background=self._prev_colors[1])
        self._prev_colors = temp