import numpy as np
from pyqtgraph import (
    InfiniteLine,
    ErrorBarItem,
    LinearRegionItem, mkBrush,
    mkColor, mkPen,
    PlotDataItem,
    LegendItem,
)

from nmreval.lib.colors import BaseColor, Colors
from nmreval.lib.lines import LineStyle
from nmreval.lib.symbols import SymbolStyle

from ..Qt import QtCore, QtGui

"""
Subclasses of pyqtgraph items, mostly to take care of log-scaling.
pyqtgraph looks for a method named "setLogMode" when axes are logarithmic, so it needs to be implemented.
"""


class LogInfiniteLine(InfiniteLine):
    """
    InfiniteLine that follows logarithmic axes.

    ``self.logmode`` holds one flag per axis (x, y). For an axis in log mode the position handed
    to ``setPos`` is the log10 of the value; :meth:`setValue` and :meth:`value` take and return
    linear values.
    """

    def __init__(self, **kwargs):
        # per-axis log flags [x, y]; assigned before the base initialiser runs
        self.logmode = [False, False]

        super().__init__(**kwargs)

    def setLogMode(self, xmode, ymode):
        """
        Convert the stored position when an axis switches between linear and logarithmic.

        Does only work for vertical and horizontal lines.

        :param xmode: True if the x axis is logarithmic
        :param ymode: True if the y axis is logarithmic
        """
        if self.logmode == [xmode, ymode]:
            return

        new_p = list(self.p[:])
        if self.logmode[0] != xmode:
            if xmode:
                # eps keeps log10(0) finite
                new_p[0] = np.log10(new_p[0]+np.finfo(float).eps)
            else:
                new_p[0] = 10**new_p[0]

        if self.logmode[1] != ymode:
            if ymode:
                new_p[1] = np.log10(new_p[1]+np.finfo(float).eps)
            else:
                new_p[1] = 10**new_p[1]

        self.logmode = [xmode, ymode]

        if np.all(np.isfinite(new_p)):
            self.setPos(new_p)
        else:
            # conversion failed (e.g. negative value on a log axis): keep the old position
            # and emit the change signal explicitly
            self.setPos(self.p)
            self.sigPositionChanged.emit(self)

    def setValue(self, v):
        """
        Set the position from linear value(s).

        :param v: QPointF, a two-element list/tuple (x, y), or a single number for vertical
                  (angle 90) and horizontal (angle 0) lines
        :raises ValueError: if a single number is given for a diagonal line
        """
        if isinstance(v, QtCore.QPointF):
            v = [v.x(), v.y()]
        elif isinstance(v, (list, tuple)):
            # work on a copy: tuples do not support item assignment and
            # a list passed by the caller must not be changed
            v = list(v)

        with np.errstate(divide='ignore'):
            if isinstance(v, list):
                for i in [0, 1]:
                    if self.logmode[i]:
                        v[i] = np.log10(v[i]+np.finfo(float).eps)
            else:
                if self.angle == 90:
                    if self.logmode[0]:
                        v = [np.log10(v+np.finfo(float).eps), 0]
                    else:
                        v = [v, 0]
                elif self.angle == 0:
                    if self.logmode[1]:
                        v = [0, np.log10(v+np.finfo(float).eps)]
                    else:
                        v = [0, v]
                else:
                    raise ValueError('LogInfiniteLine: Diagonal lines need two values')

        self.setPos(v)

    def value(self):
        """
        Return the position as linear value(s).

        :return: a single number for horizontal (angle 0) and vertical (angle 90) lines,
                 otherwise a list [x, y]
        """
        # copy: getPos may hand out the internal position list, which must not be modified here
        p = list(self.getPos())
        if self.angle == 0:
            return 10**p[1] if self.logmode[1] else p[1]
        elif self.angle == 90:
            return 10**p[0] if self.logmode[0] else p[0]
        else:
            if self.logmode[0]:
                p[0] = 10**p[0]
            if self.logmode[1]:
                p[1] = 10**p[1]
            return p


class ErrorBars(ErrorBarItem):
    """
    ErrorBarItem with support for logarithmic axes.

    The values passed by the user are kept in ``opts`` under the keys ``xData``, ``yData``,
    ``topData`` and ``bottomData``. The keys ``x``, ``y``, ``top`` and ``bottom`` hold the values
    that are actually drawn, i.e. the log-scaled ones while an axis is in log mode.
    """

    def __init__(self, **opts):
        # log flags [x, y]; assigned before the base initialiser runs
        self.log = [False, False]

        opts['xData'] = opts.get('x', None)
        opts['yData'] = opts.get('y', None)
        opts['topData'] = opts.get('top', None)
        opts['bottomData'] = opts.get('bottom', None)

        super().__init__(**opts)

    def setLogMode(self, x_mode, y_mode):
        """
        Recalculate the drawn values for the given axis modes.

        :param x_mode: True if the x axis is logarithmic
        :param y_mode: True if the y axis is logarithmic
        """
        if self.log == [x_mode, y_mode]:
            return

        # also called when both modes are False to restore the linear values
        self._make_log_scale(x_mode, y_mode)

        self.log[0] = x_mode
        self.log[1] = y_mode

        super().setData()

    def setData(self, **opts):
        """
        Update options and data; the keys x, y, top and bottom are taken as linear values.
        """
        self.opts.update(opts)

        # remember the linear values, keep the previous ones for keys that were not passed
        self.opts['xData'] = opts.get('x', self.opts.get('xData'))
        self.opts['yData'] = opts.get('y', self.opts.get('yData'))
        self.opts['topData'] = opts.get('top', self.opts.get('topData'))
        self.opts['bottomData'] = opts.get('bottom', self.opts.get('bottomData'))

        if any(self.log):
            self._make_log_scale(*self.log)

        super().setData()

    def _make_log_scale(self, x_mode, y_mode):
        """
        Fill opts['x'], opts['y'], opts['top'] and opts['bottom'] from the stored linear data.

        Points whose x or y is NaN (after taking log10 where requested) are removed.
        Does nothing while x or y data is missing; missing top/bottom lengths stay None.
        """
        x_data = self.opts['xData']
        y_data = self.opts['yData']
        if x_data is None or y_data is None:
            # nothing to convert yet, e.g. log mode is set before any data
            return

        # accept plain sequences as well as arrays
        _x = np.asarray(x_data, dtype=float)
        _y = np.asarray(y_data, dtype=float)

        _top = self.opts['topData']
        if _top is not None:
            _top = np.asarray(_top, dtype=float)
        _bottom = self.opts['bottomData']
        if _bottom is not None:
            _bottom = np.asarray(_bottom, dtype=float)

        _xmask = np.logical_not(np.isnan(_x))
        if x_mode:
            with np.errstate(all='ignore'):
                _x = np.log10(_x)
                _xmask = np.logical_not(np.isnan(_x))

        _ymask = np.ones(_y.size, dtype=bool)
        if y_mode:
            with np.errstate(all='ignore'):
                lin_y = _y
                _y = np.log10(lin_y)
                _ymask = np.logical_not(np.isnan(_y))

                if _top is not None:
                    # bar length in log space: log10(y + top) - log10(y)
                    logtop = np.log10(_top + lin_y)
                    at_zero = logtop == -np.inf
                    logtop[at_zero] = _y[at_zero]
                    _top = np.nan_to_num(np.maximum(logtop - _y, 0))

                if _bottom is not None:
                    # bar length in log space: log10(y) - log10(y - bottom);
                    # bars that reach zero or below end up with length 0
                    logbottom = np.log10(lin_y - _bottom)
                    at_zero = logbottom == -np.inf
                    logbottom[at_zero] = _y[at_zero]
                    _bottom = np.nan_to_num(np.maximum(_y - logbottom, 0))

        _mask = np.logical_and(_xmask, _ymask)

        self.opts['x'] = _x[_mask]
        self.opts['y'] = _y[_mask]
        self.opts['top'] = None if _top is None else _top[_mask]
        self.opts['bottom'] = None if _bottom is None else _bottom[_mask]


class PlotItem(PlotDataItem):
    """
    PlotDataItem that remembers its line and symbol colors in ``opts`` (keys ``linecolor`` and
    ``symbolcolor``) and offers accessors and setters based on SymbolStyle, LineStyle and Colors.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.opts['linecolor'] = (0, 0, 0)
        self.opts['symbolcolor'] = (0, 0, 0)

        pen = self.opts['pen']
        if pen is not None:
            if isinstance(pen, tuple):
                self.opts['linecolor'] = pen
                self.opts['pen'] = mkPen(color=pen)
            else:
                if not isinstance(pen, QtGui.QPen):
                    # strings, lists, dicts, QColor, ...: let pyqtgraph build the pen so that
                    # opts['pen'] is always a QPen (the properties below rely on that)
                    pen = mkPen(pen)
                    self.opts['pen'] = pen
                c = pen.color()
                self.opts['linecolor'] = c.red(), c.green(), c.blue()

        if self.symbol != SymbolStyle.No:
            brush = self.opts['symbolBrush']
            if isinstance(brush, tuple):
                self.opts['symbolcolor'] = brush
            else:
                if isinstance(brush, str):
                    # mkColor understands '#rrggbb' as well as the other string forms of pyqtgraph
                    c = mkColor(brush)
                else:
                    # QBrush, QColor, None, ...
                    c = mkBrush(brush).color()
                self.opts['symbolcolor'] = c.red(), c.green(), c.blue()

    def __getitem__(self, item):
        """Return the option `item`, or None if it is not set."""
        return self.opts.get(item, None)

    @property
    def symbol(self):
        """Symbol as returned by SymbolStyle.from_str for the current symbol option."""
        return SymbolStyle.from_str(self.opts['symbol'])

    @property
    def symbolcolor(self):
        """Symbol color; tuples and strings are converted to Colors, anything else is returned as is."""
        sc = self.opts['symbolcolor']
        if isinstance(sc, tuple):
            return Colors(sc)
        elif isinstance(sc, str):
            return Colors.from_str(sc)
        else:
            return sc

    @property
    def symbolsize(self):
        """Value of the symbolSize option."""
        return self.opts['symbolSize']

    @property
    def linestyle(self) -> LineStyle:
        """Style of the pen, LineStyle.No if there is no pen."""
        pen = self.opts['pen']
        if pen is None:
            return LineStyle.No
        else:
            return LineStyle(pen.style())

    @property
    def linewidth(self) -> float:
        """Width of the pen, 1.0 if there is no pen."""
        pen = self.opts['pen']
        if pen is None:
            return 1.
        else:
            return pen.widthF()

    @property
    def linecolor(self) -> Colors:
        """Line color; tuples and strings are converted to Colors, anything else is returned as is."""
        lc = self.opts['linecolor']
        if isinstance(lc, tuple):
            return Colors(lc)
        elif isinstance(lc, str):
            return Colors.from_str(lc)
        else:
            return lc

    def updateItems(self, styleUpdate=True):
        """
        We override this function so that curves with nan/inf values can be displayed.
        Newer versions close this bug differently (https://github.com/pyqtgraph/pyqtgraph/pull/1058)
        but this works somewhat.

        `styleUpdate` is accepted but not used here.
        """

        curveArgs = {}
        for k, v in [('pen', 'pen'), ('shadowPen', 'shadowPen'), ('fillLevel', 'fillLevel'),
                     ('fillOutline', 'fillOutline'), ('fillBrush', 'brush'), ('antialias', 'antialias'),
                     ('connect', 'connect'), ('stepMode', 'stepMode')]:
            curveArgs[v] = self.opts[k]

        scatterArgs = {}
        for k, v in [('symbolPen', 'pen'), ('symbolBrush', 'brush'), ('symbol', 'symbol'), ('symbolSize', 'size'),
                     ('data', 'data'), ('pxMode', 'pxMode'), ('antialias', 'antialias')]:
            if k in self.opts:
                scatterArgs[v] = self.opts[k]

        x, y = self.getData()
        if x is None:
            x = []
        if y is None:
            y = []

        if curveArgs['pen'] is not None or (curveArgs['brush'] is not None and curveArgs['fillLevel'] is not None):
            # a common mask needs x and y of equal length; step-mode data
            # (one more x than y) is passed on unfiltered
            if len(x) == len(y):
                is_finite = np.isfinite(x) & np.isfinite(y)
                all_finite = np.all(is_finite)
                if not all_finite:
                    # remove all bad values
                    x = x[is_finite]
                    y = y[is_finite]
            # the connect option is always replaced: remaining points are joined
            curveArgs['connect'] = 'all'
            self.curve.setData(x=x, y=y, **curveArgs)
            self.curve.show()
        else:
            self.curve.hide()

        if scatterArgs['symbol'] is not None:
            if self.opts.get('stepMode', False) is True:
                # symbols are placed at the centre of each step
                x = 0.5 * (x[:-1] + x[1:])
            self.scatter.setData(x=x, y=y, **scatterArgs)
            self.scatter.show()
        else:
            self.scatter.hide()

    def set_symbol(self, *, symbol=None, size=None, color=None):
        """
        Change symbol, symbol size and/or symbol color; arguments that are None are left alone.

        :param symbol: int or SymbolStyle (converted with to_str), anything else is passed to setSymbol
        :param size: new symbol size
        :param color: new symbol color, see set_color
        """
        if symbol is not None:
            if isinstance(symbol, int):
                self.setSymbol(SymbolStyle(symbol).to_str())
            elif isinstance(symbol, SymbolStyle):
                self.setSymbol(symbol.to_str())
            else:
                self.setSymbol(symbol)

        if color is not None:
            self.set_color(color, symbol=True)

        if size is not None:
            self.setSymbolSize(size)

    def set_color(self, color, symbol=False, line=False):
        """
        Set the color of the symbols and/or the line.

        :param color: BaseColor (its rgb() is used), QColor (its first three rgb components are used),
                      or anything mkBrush/mkPen/mkColor accept
        :param symbol: if True, symbol brush and symbol pen get the color
        :param line: if True, the existing pen (if any) gets the color
        """
        if isinstance(color, BaseColor):
            color = color.rgb()
        elif isinstance(color, QtGui.QColor):
            color = color.getRgb()[:3]

        if symbol:
            self.setSymbolBrush(mkBrush(color))
            self.setSymbolPen(mkPen(color=color))
            self.opts['symbolcolor'] = color

        if line:
            pen = self.opts['pen']
            # the color is remembered even when there is no pen at the moment
            self.opts['linecolor'] = color
            if pen is not None:
                pen.setColor(mkColor(color))
            self.opts['pen'] = pen
            self.updateItems()

    def set_line(self, *, style=None, width=None, color=None):
        """
        Change style, width and/or color of the line; arguments that are None are left alone.

        :param style: LineStyle (its value is used) or a value accepted by QPen.setStyle
        :param width: new line width
        :param color: new line color, see set_color
        """
        pen = self.opts['pen']
        if pen is None:
            # start from an invisible pen
            pen = mkPen(style=QtCore.Qt.NoPen)

        if width is not None:
            pen.setWidthF(width)
        if style is not None:
            if isinstance(style, LineStyle):
                style = style.value

            pen.setStyle(style)

        self.opts['pen'] = pen
        self.updateItems()

        if color is not None:
            self.set_color(color, symbol=False, line=True)

    def get_data_opts(self) -> dict:
        """
        Collect data and style of this item in a dictionary.

        :return: empty dict if there is no x data, otherwise a dict with the keys x, y, name,
                 symbolsize, symbol, symbolcolor, linestyle, linecolor and linewidth
        """
        x, y = self.xData, self.yData
        if (x is None) or (len(x) == 0):
            return {}

        opts = self.opts
        item_dic = {
            'x': x,
            'y': y,
            'name': opts.get('name', ''),
            'symbolsize': opts['symbolSize'],
        }

        if opts['symbol'] is None:
            item_dic['symbol'] = SymbolStyle.No
            item_dic['symbolcolor'] = None
        else:
            item_dic['symbol'] = SymbolStyle.from_str(opts['symbol'])
            item_dic['symbolcolor'] = opts['symbolcolor']

        pen = opts['pen']
        if pen is not None:
            item_dic['linestyle'] = LineStyle(pen.style())
            item_dic['linecolor'] = opts['linecolor']
            item_dic['linewidth'] = pen.widthF()
        else:
            item_dic['linestyle'] = LineStyle.No
            item_dic['linecolor'] = None
            item_dic['linewidth'] = 0.0

        # a missing color is replaced by the other one; black symbols if both are missing
        if item_dic['linecolor'] is None and item_dic['symbolcolor'] is None:
            item_dic['symbolcolor'] = Colors.Black.rgb()
        elif item_dic['linecolor'] is None:
            item_dic['linecolor'] = item_dic['symbolcolor']
        elif item_dic['symbolcolor'] is None:
            item_dic['symbolcolor'] = item_dic['linecolor']

        return item_dic


class RegionItem(LinearRegionItem):
    """
    LinearRegionItem that follows a logarithmic axis.

    While ``self.logmode`` is True the border lines hold the log10 of the region limits;
    :meth:`getRegion` returns linear values.
    """

    def __init__(self, *args, **kwargs):
        self.mode = kwargs.pop('mode', 'half')
        # assigned before the base initialiser runs so that the overridden methods
        # below can be called at any time
        self.logmode = False

        super().__init__(*args, **kwargs)

        self.first = True
        if not hasattr(self, '_bounds'):
            # boundingRect below compares against self._bounds, make sure it always exists
            self._bounds = getattr(self, '_boundingRectCache', None)

        for l in self.lines:
            # higher z for borderlines improves chances that you can move it when multiple regions overlap
            l.setZValue(self.zValue() + 1)

    def setLogMode(self, xmode, _):
        """
        Convert the region limits when switching between linear and logarithmic mode.

        :param xmode: True for logarithmic mode
        :param _: ignored (mode of the other axis)
        """
        if self.logmode == xmode:
            return

        if xmode:
            with np.errstate(all='ignore'):
                new_region = [np.log10(self.lines[0].value()), np.log10(self.lines[1].value())]

            # limits <= 0 have no finite logarithm (NaN or -inf): fall back to the
            # unconverted upper limit and to a tenth of the upper limit, respectively
            if not np.isfinite(new_region[1]):
                new_region[1] = self.lines[1].value()

            if not np.isfinite(new_region[0]):
                new_region[0] = new_region[1]/10.

        else:
            new_region = [10**self.lines[0].value(), 10**self.lines[1].value()]

        # set the flag first: setRegion moves the lines and getRegion depends on the flag
        self.logmode = xmode
        self.setRegion(new_region)

    def dataBounds(self, axis, frac=1.0, orthoRange=None):
        """
        Return the region in view coordinates (log10 values in log mode) for the axis of this
        region, None for the other axis.
        """
        if axis == self._orientation_axis[self.orientation]:
            # the base class holds the limits in view coordinates already
            return super().getRegion()
        else:
            return None

    def getRegion(self):
        """Return the region limits as linear values."""
        region = super().getRegion()
        if self.logmode:
            return 10**region[0], 10**region[1]
        else:
            return region

    def setRegion(self, region, use_log=False):
        """
        Set the region limits.

        :param region: pair of limits
        :param use_log: if True and the item is in log mode, log10 of `region` is used,
                        otherwise `region` is used as given
        :raises ValueError: if a resulting limit is not finite
        """
        if self.logmode and use_log:
            # invalid limits are reported by the ValueError below, no need for numpy warnings
            with np.errstate(all='ignore'):
                region = np.log10(region[0]), np.log10(region[1])

        if not np.all(np.isfinite(region)):
            raise ValueError(f'Invalid region limits ({region[0]}, {region[1]})')
        else:
            super().setRegion(region)

    def boundingRect(self):
        # overwrite to draw correct rect in logmode

        br = self.viewRect()  # bounds of containing ViewBox mapped to local coords.

        # limits in view coordinates, i.e. log10 values in log mode
        rng = super().getRegion()

        if self.orientation in ('vertical', LinearRegionItem.Vertical):
            br.setLeft(rng[0])
            br.setRight(rng[1])
            length = br.height()
            br.setBottom(br.top() + length * self.span[1])
            br.setTop(br.top() + length * self.span[0])
        else:
            br.setTop(rng[0])
            br.setBottom(rng[1])
            length = br.width()
            br.setRight(br.left() + length * self.span[1])
            br.setLeft(br.left() + length * self.span[0])

        br = br.normalized()

        if self._bounds != br:
            self._bounds = br
            self.prepareGeometryChange()

        return br


class LegendItemBlock(LegendItem):
    """
    LegendItem that cannot be dragged out of the rect of its parent item
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        margins = (1, 1, 1, 1)
        self.layout.setContentsMargins(*margins)

    def mouseDragEvent(self, ev):
        """
        Move the legend with a left-button drag as long as it stays inside the parent's rect.
        """
        if ev.button() != QtCore.Qt.LeftButton:
            return

        ev.accept()

        shift = ev.pos() - ev.lastPos()

        top_left = self.pos()
        bottom_right = self.pos()
        bottom_right.setX(bottom_right.x() + self.width())
        bottom_right.setY(bottom_right.y() + self.height())

        parent_rect = self.parentItem().rect()

        # both corners have to stay inside, otherwise the legend is anchored where it is
        stays_inside = parent_rect.contains(top_left + shift) and parent_rect.contains(bottom_right + shift)
        target = top_left + shift if stays_inside else top_left
        self.autoAnchor(target)