from __future__ import annotations

from nmreval.utils.text import convert

from ..Qt import QtWidgets, QtCore, QtGui
from .._py.fitfuncwidget import Ui_FormFit
from ..lib.forms import SelectionWidget
from .fit_forms import FitModelWidget


class QFitParameterWidget(QtWidgets.QWidget, Ui_FormFit):
    """Editor for the parameters of one fit function.

    ``scrollwidget`` holds one widget per global parameter (a ``FitModelWidget``
    per numeric parameter, followed by one ``SelectionWidget`` per choice) and
    ``scrollwidget2`` holds the matching per-data-set widgets in the same order.
    Values set for a single data set are stored in ``data_values``
    (data set id -> list with one entry per per-data-set widget); a ``None``
    entry means "use the global value".
    """

    # Emitted with the position (in ``global_parameter``) of the widget that requested a value.
    value_requested = QtCore.pyqtSignal(int)

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self.setupUi(self)

        self.func = None
        self.func_idx = None
        # size hint of the widest parameter label; applied as fixed size to align the labels
        self.max_width = QtCore.QSize(0, 0)
        self.global_parameter = []
        self.data_parameter = []
        self.glob_values = None
        self.data_values = {}

        self.scrollwidget.setLayout(QtWidgets.QVBoxLayout())
        self.scrollwidget2.setLayout(QtWidgets.QVBoxLayout())

    def eventFilter(self, src: QtCore.QObject, evt: QtCore.QEvent):
        """Handle Ctrl+Shift+Right/Left on the per-data-set parameter widgets.

        The current value of *src* is stored as local value of the current data
        set, then the next (Right) or previous (Left) data set is selected.
        """
        if isinstance(evt, QtGui.QKeyEvent) and \
                evt.modifiers() == QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier:
            key = evt.key()
            if key == QtCore.Qt.Key_Right:
                direction = 1
            elif key == QtCore.Qt.Key_Left:
                direction = -1
            else:
                direction = None

            if direction is not None:
                self.change_single_parameter(src.value, sender=src)
                self.select_next_preview(direction)
                return True

        return super().eventFilter(src, evt)

    def load(self, data):
        """Replace the data sets in the combobox.

        :param data: iterable of ``(set id, display name)`` pairs.
        """
        # signals are blocked so that change_data runs exactly once, after all items exist
        self.comboBox.blockSignals(True)
        self.comboBox.clear()

        for sid, name in data:
            self.comboBox.addItem(name, userData=sid)
            self._make_parameter(sid)
        self.comboBox.blockSignals(False)
        self.change_data(0)

    def set_function(self, func, idx):
        """Create the global and per-data-set widgets for fit function *func*.

        :param func: fit function; ``func.params`` lists the numeric parameter names,
            ``func.bounds[k]`` (optional) is unpacked into ``FitModelWidget.set_bounds``
            and ``func.choices`` (optional) lists the argument tuples of ``SelectionWidget``.
        :param idx: function index, stored on every ``FitModelWidget``.
        """
        self.func = func
        self.func_idx = idx

        self.glob_values = [1] * len(func.params)

        for k, v in enumerate(func.params):
            name = convert(v)
            widgt = FitModelWidget(label=name, parent=self.scrollwidget)
            widgt.parameter_pos = k
            widgt.func_idx = idx
            try:
                widgt.set_bounds(*func.bounds[k], False)
            except (AttributeError, IndexError):
                # no (or too few) bounds given: keep the widget defaults
                pass

            size = widgt.parametername.sizeHint()
            if self.max_width.width() < size.width():
                self.max_width = size

            widgt.state_changed.connect(self.make_global)
            widgt.value_requested.connect(self.look_for_value)
            widgt.value_changed.connect(self.change_global_parameter)

            self.global_parameter.append(widgt)
            self.scrollwidget.layout().addWidget(widgt)

            widgt2 = ParameterSingleWidget(name=name, parent=self.scrollwidget2)
            widgt2.valueChanged.connect(self.change_single_parameter)
            widgt2.removeSingleValue.connect(self.change_single_parameter)
            widgt2.installEventFilter(self)
            self.scrollwidget2.layout().addWidget(widgt2)
            self.data_parameter.append(widgt2)

        for w1, w2 in zip(self.global_parameter, self.data_parameter):
            w1.parametername.setFixedSize(self.max_width)
            w1.checkBox.setFixedSize(self.max_width)
            w2.label.setFixedSize(self.max_width)

        # choice widgets are appended after all numeric widgets; get_parameter relies on this order
        choices = getattr(func, 'choices', None)
        if choices is not None:
            for c in choices:
                widgt = SelectionWidget(*c)
                widgt.selectionChanged.connect(self.change_global_choice)
                self.global_parameter.append(widgt)
                self.glob_values.append(widgt.value)
                self.scrollwidget.layout().addWidget(widgt)

                widgt2 = SelectionWidget(*c)
                widgt2.selectionChanged.connect(self.change_single_choice)
                self.data_parameter.append(widgt2)
                self.scrollwidget2.layout().addWidget(widgt2)

        # (re)size the local value lists of data sets loaded before this call
        for i in range(self.comboBox.count()):
            self._make_parameter(self.comboBox.itemData(i))

        self.scrollwidget.layout().addStretch(1)
        self.scrollwidget2.layout().addStretch(1)

    def set_links(self, parameter):
        """Forward *parameter* to ``add_links`` of every numeric global widget."""
        for w in self.global_parameter:
            if isinstance(w, FitModelWidget):
                w.add_links(parameter)

    @QtCore.pyqtSlot(str)
    def change_global_parameter(self, value: str, idx: int = None):
        """Store a new global value and show it if the current data set has no local value.

        :param value: new value as emitted by the global widget.
        :param idx: parameter position; taken from the sending widget if None.
        """
        if idx is None:
            idx = self.global_parameter.index(self.sender())

        self.glob_values[idx] = value
        if self.data_values[self.comboBox.currentData()][idx] is None:
            self.data_parameter[idx].blockSignals(True)
            self.data_parameter[idx].value = value
            self.data_parameter[idx].blockSignals(False)

    @QtCore.pyqtSlot(str, object)
    def change_global_choice(self, _, value):
        """Store a new global choice and show it if the current data set has no local choice."""
        idx = self.global_parameter.index(self.sender())
        self.glob_values[idx] = value
        if self.data_values[self.comboBox.currentData()][idx] is None:
            self.data_parameter[idx].blockSignals(True)
            self.data_parameter[idx].value = value
            self.data_parameter[idx].blockSignals(False)

    def change_single_parameter(self, value: float = None, sender=None):
        """Store *value* as local value of the current data set.

        ``None`` removes the local value; the display is then refreshed so the
        global value is shown again.
        """
        if sender is None:
            sender = self.sender()
        idx = self.data_parameter.index(sender)
        self.data_values[self.comboBox.currentData()][idx] = value
        if value is None:
            self.change_data(self.comboBox.currentIndex())

    def change_single_choice(self, _, value, sender=None):
        """Store *value* as local choice of the current data set."""
        if sender is None:
            sender = self.sender()
        idx = self.data_parameter.index(sender)
        self.data_values[self.comboBox.currentData()][idx] = value

    @QtCore.pyqtSlot(object)
    def look_for_value(self, sender):
        """Re-emit a value request of a global widget as its parameter position."""
        self.value_requested.emit(self.global_parameter.index(sender))

    @QtCore.pyqtSlot()
    def make_global(self):
        """Disable the per-data-set widget while its parameter is global or linked."""
        widget = self.sender()
        idx = self.global_parameter.index(widget)
        enable = (widget.global_checkbox.checkState() == QtCore.Qt.Unchecked) and (widget.is_linked is None)
        self.data_parameter[idx].setEnabled(enable)

    def select_next_preview(self, direction):
        """Select the data set *direction* steps away, wrapping around at both ends."""
        curr_idx = self.comboBox.currentIndex()
        next_idx = (curr_idx + direction) % self.comboBox.count()
        self.comboBox.setCurrentIndex(next_idx)

    @QtCore.pyqtSlot(int, name='on_comboBox_currentIndexChanged')
    def change_data(self, idx: int):
        """Show the values of data set *idx*: local values if set, global values otherwise."""
        sid = self.comboBox.itemData(idx)
        self._make_parameter(sid)

        for i, value in enumerate(self.data_values[sid]):
            w = self.data_parameter[i]
            w.blockSignals(True)
            try:
                w.show_as_local_parameter(value is not None)
            except AttributeError:
                # SelectionWidget has no local/global highlighting
                pass

            w.value = self.glob_values[i] if value is None else value
            w.blockSignals(False)

    def _make_parameter(self, sid):
        """Ensure *sid* has one local-value slot (``None`` = use global) per per-data-set widget.

        Entries created before ``set_function`` added widgets are padded as well,
        so indexing with a widget position never fails.
        """
        values = self.data_values.setdefault(sid, [])
        values.extend([None] * (len(self.data_parameter) - len(values)))

    @staticmethod
    def _check_bounds(widget, value, lower, upper):
        """Raise ValueError if *value* lies outside (*lower*, *upper*).

        A bound that cannot be compared with *value* (TypeError, e.g. ``None``)
        is ignored; each bound is checked independently.
        """
        try:
            too_high = bool(value > upper)
        except TypeError:
            too_high = False

        try:
            too_low = bool(value < lower)
        except TypeError:
            too_low = False

        if too_high or too_low:
            raise ValueError(f'Parameter {widget.name} is outside bounds ({lower}, {upper})')

    def get_parameter(self, use_func=None):
        """Collect the parameters of all (or the selected) data sets.

        :param use_func: data set ids to include; None includes every known id.
        :returns: ``(data_parameter, lb, ub, is_fixed, global_p, is_linked)``.
            ``data_parameter`` maps set id -> ``(numeric values, keyword arguments of the choices)``,
            ``lb``/``ub`` are tuples of lower/upper bounds and ``global_p`` collects the global
            parameters (None if no data set was included).
        :raises ValueError: if a numeric value lies outside its bounds.
        """
        bds = []
        is_global = []
        is_fixed = []
        globs = []
        is_linked = []

        for g in self.global_parameter:
            if isinstance(g, FitModelWidget):
                p_i, bds_i, fixed_i, global_i, link_i = g.get_parameter()

                globs.append(p_i)
                bds.append(bds_i)
                is_fixed.append(fixed_i)
                is_global.append(global_i)
                is_linked.append(link_i)

        # zip(*[]) yields nothing and cannot be unpacked: handle functions without numeric parameters
        lb, ub = tuple(zip(*bds)) if bds else ((), ())

        data_parameter = {}
        if use_func is None:
            use_func = list(self.data_values.keys())

        global_p = None
        for sid, parameter in self.data_values.items():
            if sid not in use_func:
                continue

            if global_p is None:
                global_p = {'value': [], 'idx': [], 'var': [], 'ub': [], 'lb': []}

            p = []
            kw_p = {}
            # numeric widgets precede choice widgets, so i also indexes globs/bounds/is_global
            for i, (p_i, g) in enumerate(zip(parameter, self.global_parameter)):
                if isinstance(g, FitModelWidget):
                    use_global = (p_i is None) or is_global[i]
                    value = globs[i] if use_global else p_i
                    p.append(value)

                    if is_global[i] and i not in global_p['idx']:
                        global_p['value'].append(globs[i])
                        global_p['idx'].append(i)
                        global_p['var'].append(is_fixed[i])
                        global_p['ub'].append(ub[i])
                        global_p['lb'].append(lb[i])

                    self._check_bounds(g, value, lb[i], ub[i])

                elif p_i is None:
                    kw_p.update(g.value)
                elif isinstance(p_i, dict):
                    kw_p.update(p_i)
                else:
                    kw_p[g.argname] = p_i

            data_parameter[sid] = (p, kw_p)

        return data_parameter, lb, ub, is_fixed, global_p, is_linked

    def set_parameter(self, set_id: str | None, parameter: list[float]) -> int:
        """Set numeric parameter values globally (*set_id* None) or as local values of one data set.

        Only the first entries of *parameter* (one per numeric parameter) are used.

        :returns: number of numeric parameters of this function.
        """
        num_parameter = [g for g in self.global_parameter if not isinstance(g, SelectionWidget)]
        param_len = len(num_parameter)
        if set_id is None:
            for i, g in enumerate(num_parameter):
                val = parameter[i]
                g.set_parameter(val)
                self.glob_values[i] = val

        else:
            new_param = self.data_values[set_id]
            for i in range(min(param_len, len(new_param))):
                new_param[i] = parameter[i]

            self.change_data(self.comboBox.currentIndex())

        return param_len


class ParameterSingleWidget(QtWidgets.QWidget):
    """Label, line edit and "Use global" button for one parameter of a single data set.

    The label is shown bold while the displayed value is a local one.
    """

    # Emitted with the parsed value (always a float) whenever the line edit text changes.
    valueChanged = QtCore.pyqtSignal(object)
    # Emitted when the "Use global" button is clicked.
    removeSingleValue = QtCore.pyqtSignal()

    def __init__(self, name: str, parent=None):
        super().__init__(parent=parent)

        self._init_ui()

        self._name = name
        # NOTE(review): callers may already pass a convert()-ed name; confirm converting twice is harmless
        self.label.setText(convert(name))
        self.label.setToolTip('If this is bold then this parameter is only for this data. '
                              'Otherwise the general parameter is used and displayed')

        # ``value`` never returns None (unparsable text gives 0.0), so every change is forwarded
        self.value_line.textChanged.connect(lambda: self.valueChanged.emit(self.value))
        self.reset_button.clicked.connect(lambda _checked: self.removeSingleValue.emit())

    def _init_ui(self):
        """Build the label / line edit / reset-button row."""
        # passing self installs the layout on this widget; no separate setLayout needed
        layout = QtWidgets.QHBoxLayout(self)
        layout.setContentsMargins(2, 2, 2, 2)
        layout.setSpacing(2)

        self.label = QtWidgets.QLabel(self)
        layout.addWidget(self.label)

        layout.addSpacerItem(QtWidgets.QSpacerItem(0, 0, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum))

        self.value_line = QtWidgets.QLineEdit(self)
        # only user edits (not programmatic setText) mark the value as local
        self.value_line.textEdited.connect(lambda x: self.show_as_local_parameter(True))
        layout.addWidget(self.value_line)

        self.reset_button = QtWidgets.QToolButton(self)
        self.reset_button.setText('Use global')
        self.reset_button.clicked.connect(lambda: self.show_as_local_parameter(False))
        layout.addWidget(self.reset_button)

    @property
    def value(self) -> float:
        """Line edit text as float; a comma is accepted as decimal separator, invalid text gives 0.0."""
        try:
            return float(self.value_line.text().replace(',', '.'))
        except ValueError:
            return 0.0

    @value.setter
    def value(self, val):
        self.value_line.setText(f'{val}')

    def show_as_local_parameter(self, is_local):
        """Show the label bold if *is_local* is true, with the default style otherwise."""
        if is_local:
            self.label.setStyleSheet('font-weight: bold;')
        else:
            self.label.setStyleSheet('')