import numpy as np
from PyQt5.QtCore import Qt, QObject, pyqtSignal
from PyQt5.QtWidgets import QGraphicsLineItem
from doocspie.util.colors import Colors
from pyqtgraph import GraphicsLayoutWidget, PlotDataItem, mkPen


class LoggerPlotter(QObject):
    """Horizontal and vertical pyqtgraph panels for logged metrics and beam projections.

    Each panel (``horizontal_widget`` / ``vertical_widget``) is a ``GraphicsLayoutWidget``
    with two stacked plots:

    * a metrics plot showing one value per sample (x = 1 .. N), overlaid with a
      vertically draggable horizontal marker line, and
    * a projection plot showing a projection curve plus an optional fit curve.

    The marker line's vertical drag offset and its reference value are tracked per
    metric label (the ``labels`` given to the constructor), so switching labels can
    restore a previously chosen marker position.

    Signals:
        moved: emitted whenever one of the marker lines is dragged with the mouse.
    """

    moved = pyqtSignal()

    # NOTE(review): Colors.DESY.* are presumably RGB(A) components in the 0..1 range;
    # they are scaled by 255 before being handed to pyqtgraph pens — confirm.
    _DESY_CYAN = np.array(Colors.DESY.CYAN) * 255
    _DESY_ORANGE = np.array(Colors.DESY.ORANGE) * 255

    class _MovableLine(QGraphicsLineItem):
        """Horizontal line item that can be dragged vertically with the left mouse button.

        The line geometry (``setLine``) is anchored at a reference y value; dragging only
        changes the item's position offset, so the displayed level is
        ``line().y1() + pos().y()``.
        """

        def __init__(self, parent, plot, color, width):
            """Create the line.

            Args:
                parent: owning ``LoggerPlotter``; its ``moved`` signal is emitted while dragging.
                plot: pyqtgraph plot the line belongs to; its y auto-range is suspended
                    while the line is being dragged.
                color: pen colour passed to ``mkPen``.
                width: pen width passed to ``mkPen``.
            """
            super().__init__()
            self._parent = parent
            self._plot = plot
            self.setPen(mkPen(color=color, width=width))
            self.setAcceptHoverEvents(True)
            # Pen active before hovering; restored when the cursor leaves the line.
            self._saved_pen = None
            # Offset between the grab point (parent coords) and the item position at press time.
            self._delta_when_pressed = None
            # Last y offset applied by a drag release or by set_position().
            self._delta_when_set = 0

        def hoverEnterEvent(self, event):
            """Highlight the line by doubling its pen width while the cursor is over it."""
            self._saved_pen = self.pen()
            self.setPen(mkPen(color=self._saved_pen.color().getRgb(), width=2 * self._saved_pen.width()))
            event.ignore()

        def hoverLeaveEvent(self, event):
            """Restore the pen that was active before the hover started."""
            # Guard against a leave without a preceding enter (nothing was saved).
            if self._saved_pen is not None:
                self.setPen(self._saved_pen)
            event.ignore()

        def mousePressEvent(self, event):
            """Start a drag on left-button press; other buttons are passed on.

            The plot's y auto-range is switched off for the duration of the drag and the
            grab offset is remembered so the line does not jump to the cursor.
            """
            if event.button() == Qt.LeftButton:
                event.accept()
                self._plot.enableAutoRange(axis="y", enable=False)
                self._delta_when_pressed = self.mapToParent(event.pos()) - self.pos()
            else:
                event.ignore()

        def mouseReleaseEvent(self, event):
            """Finish a drag: re-enable y auto-range and remember the final y offset."""
            self._plot.enableAutoRange(axis="y", enable=True)
            # Only record an offset if a left-button press has established a grab offset.
            if self._delta_when_pressed is not None:
                self._delta_when_set = (self.mapToParent(event.pos()) - self._delta_when_pressed).y()
            event.ignore()

        def mouseMoveEvent(self, event):
            """Follow the cursor vertically (x offset pinned to 0) and notify the owner."""
            self.setPos(0, (self.mapToParent(event.pos()) - self._delta_when_pressed).y())
            self._parent.moved.emit()

        def set_position(self, position):
            """Move the line so that it is displayed at the plot y level ``position``.

            A single pass is exact only when ``pos().y()`` equals ``_delta_when_set``.
            The offset can also be changed elsewhere (e.g. via ``setY``), so the first
            pass re-synchronises the two. The second pass then starts from a consistent
            state and lands exactly on ``position``.
            """
            for _ in range(2):  # do twice is intentional (see docstring)
                current_position = self.line().y1() + self.pos().y() - self._delta_when_set
                self._delta_when_set = position - current_position
                self.setPos(0, self._delta_when_set)

    def __init__(self, *labels):
        """Build both panels.

        Args:
            *labels: metric labels for which marker-line offsets and reference values
                are tracked (all offsets start at 0, all reference values at None).
        """
        super().__init__()
        self._horizontal_widget = GraphicsLayoutWidget()

        self._horizontal_metrics_plot = self._horizontal_widget.addPlot()
        self._horizontal_metrics = PlotDataItem()
        self._horizontal_metrics_plot.addItem(self._horizontal_metrics)
        self._horizontal_movable_line_position = {label: 0 for label in labels}
        self._horizontal_value = {label: None for label in labels}
        self._horizontal_movable_line = self._MovableLine(self, self._horizontal_metrics_plot,
                                                          color=self._DESY_ORANGE, width=2)
        # ignoreBounds: the marker line must not influence the plot's auto-range.
        self._horizontal_metrics_plot.addItem(self._horizontal_movable_line, ignoreBounds=True)

        self._horizontal_widget.nextRow()
        self._horizontal_projection_plot = self._horizontal_widget.addPlot()
        self._horizontal_projection = PlotDataItem()
        self._horizontal_projection_plot.addItem(self._horizontal_projection)
        self._horizontal_fit = PlotDataItem()
        self._horizontal_projection_plot.addItem(self._horizontal_fit)

        self._vertical_widget = GraphicsLayoutWidget()

        self._vertical_metrics_plot = self._vertical_widget.addPlot()
        self._vertical_metrics = PlotDataItem()
        self._vertical_metrics_plot.addItem(self._vertical_metrics)
        self._vertical_movable_line_position = {label: 0 for label in labels}
        self._vertical_value = {label: None for label in labels}
        self._vertical_movable_line = self._MovableLine(self, self._vertical_metrics_plot,
                                                        color=self._DESY_ORANGE, width=2)
        self._vertical_metrics_plot.addItem(self._vertical_movable_line, ignoreBounds=True)

        self._vertical_widget.nextRow()
        self._vertical_projection_plot = self._vertical_widget.addPlot()
        self._vertical_projection = PlotDataItem()
        self._vertical_projection_plot.addItem(self._vertical_projection)
        self._vertical_fit = PlotDataItem()
        self._vertical_projection_plot.addItem(self._vertical_fit)

        self._initialize_plots()
        self._is_movable_line_visible = False

    def _initialize_plots(self):
        """Apply the initial bare design to all four plots and style the fit curves."""
        self._initialize_horizontal_projection_plot()
        self._initialize_horizontal_fit_parameter_plot()
        self._initialize_vertical_projection_plot()
        self._initialize_vertical_fit_parameter_plot()

    def _initialize_horizontal_projection_plot(self):
        """Reset the horizontal projection plot and draw its fit in DESY cyan."""
        self._initialize_design(self._horizontal_projection_plot)
        self._horizontal_fit.setData(pen={"color": self._DESY_CYAN, "width": 2})

    def _initialize_horizontal_fit_parameter_plot(self):
        """Reset the horizontal metrics plot."""
        self._initialize_design(self._horizontal_metrics_plot)

    def _initialize_vertical_projection_plot(self):
        """Reset the vertical projection plot and draw its fit in DESY cyan."""
        self._initialize_design(self._vertical_projection_plot)
        self._vertical_fit.setData(pen={"color": self._DESY_CYAN, "width": 2})

    def _initialize_vertical_fit_parameter_plot(self):
        """Reset the vertical metrics plot."""
        self._initialize_design(self._vertical_metrics_plot)

    @staticmethod
    def _initialize_design(plot):
        """Give ``plot`` a bare frame.

        The context menu is disabled, and all four axes are shown without tick values
        or tick marks. The bottom and left labels are emptied.
        """
        plot.setMenuEnabled(False)
        plot.showAxis("top", show=True)
        plot.showAxis("right", show=True)
        for pos in "top", "right", "left", "bottom":
            plot.getAxis(pos).setStyle(showValues=False, tickLength=0)
        plot.setLabel("bottom", "")
        plot.setLabel("left", "")

    @property
    def horizontal_widget(self):
        """The ``GraphicsLayoutWidget`` holding the horizontal metrics and projection plots."""
        return self._horizontal_widget

    @property
    def vertical_widget(self):
        """The ``GraphicsLayoutWidget`` holding the vertical metrics and projection plots."""
        return self._vertical_widget

    @property
    def is_movable_line_visible(self):
        """Whether the marker lines are currently enabled (see ``set_movable_line_visible``)."""
        return self._is_movable_line_visible

    def set_projections(self, values, units):
        """Plot both projections.

        Args:
            values: object with ``.horizontal`` / ``.vertical`` attributes, each exposing
                ``.x`` and ``.y`` sequences.
            units: object with ``.horizontal`` / ``.vertical`` unit strings for the x axes.
        """
        self._set_horizontal_projection(values.horizontal, units.horizontal)
        self._set_vertical_projection(values.vertical, units.vertical)

    def _set_horizontal_projection(self, values, units):
        """Plot the horizontal projection and show axis values and labels."""
        self._horizontal_projection.setData(x=values.x, y=values.y)
        self._horizontal_projection_plot.getAxis("left").setStyle(showValues=True)
        self._horizontal_projection_plot.getAxis("bottom").setStyle(showValues=True)
        self._horizontal_projection_plot.setLabel("left", "Intensity (arb. u.)")
        self._horizontal_projection_plot.setLabel("bottom", f"Horizontal ({units})")

    def _set_vertical_projection(self, values, units):
        """Plot the vertical projection and show axis values and labels."""
        self._vertical_projection.setData(x=values.x, y=values.y)
        self._vertical_projection_plot.getAxis("left").setStyle(showValues=True)
        self._vertical_projection_plot.getAxis("bottom").setStyle(showValues=True)
        self._vertical_projection_plot.setLabel("left", "Intensity (arb. u.)")
        self._vertical_projection_plot.setLabel("bottom", f"Vertical ({units})")

    def set_fits(self, fits):
        """Plot (or clear) both fit curves.

        Args:
            fits: object with ``.horizontal`` / ``.vertical`` attributes. Each is either
                None (the fit curve is cleared) or an object exposing ``.x`` and ``.y``.
        """
        self._set_fit(self._horizontal_fit, fits.horizontal)
        self._set_fit(self._vertical_fit, fits.vertical)

    @staticmethod
    def _set_fit(fit, fit_data):
        """Clear ``fit`` when ``fit_data`` is None, otherwise plot ``fit_data.x``/``.y``."""
        if fit_data is None:
            fit.clear()
        else:
            fit.setData(x=fit_data.x, y=fit_data.y)

    def set_horizontal_metrics(self, values, unit, label):
        """Plot horizontal metric ``values`` for ``label`` and update the horizontal marker."""
        self._set_metrics(self._horizontal_metrics, self._horizontal_metrics_plot, values, label, unit)
        self._set_movable_line(self._horizontal_movable_line_position, label, self._horizontal_movable_line, values,
                               self._horizontal_value)

    def set_vertical_metrics(self, values, unit, label):
        """Plot vertical metric ``values`` for ``label`` and update the vertical marker."""
        self._set_metrics(self._vertical_metrics, self._vertical_metrics_plot, values, label, unit)
        self._set_movable_line(self._vertical_movable_line_position, label, self._vertical_movable_line, values,
                               self._vertical_value)

    def _set_movable_line(self, movable_line_position, label, movable_line, values, value):
        """Anchor ``movable_line`` for ``label`` across the current samples.

        The first time finite data arrives for ``label``, its reference value
        (``value[label]``) is set to the first finite sample; afterwards it is kept.
        The line's current offset is recorded for ``label``, and the line is laid out
        from x = 1 to x = len(values) at the reference level. Its visibility then
        follows the enabled flag. Nothing happens when ``values`` is empty or contains
        no finite sample.
        """
        # Explicit length check instead of truthiness, so NumPy arrays are accepted too.
        if values is None or len(values) == 0:
            return
        samples = np.asarray(values)
        finite_values = samples[np.isfinite(samples)]
        # Test for presence, not truthiness: ``finite_values.any()`` would treat a series
        # of exact zeros as "no data" and never anchor or show the line.
        if finite_values.size == 0:
            return
        if value[label] is None:
            value[label] = finite_values[0]
        movable_line_position[label] = movable_line.pos().y()
        movable_line.setLine(1, value[label], len(values), value[label])
        movable_line.setVisible(self._is_movable_line_visible)

    @staticmethod
    def _set_metrics(metrics, metrics_plot, values, label, unit):
        """Plot ``values`` against the 1-based sample index, skipping non-finite entries.

        The axis values are switched on. The left axis is labelled ``"<label> (<unit>)"``
        and the bottom axis ``"Sample"``.
        """
        samples = np.asarray(values)
        x = np.arange(1, len(samples) + 1)
        finite = np.isfinite(samples)
        metrics.setData(x=x[finite], y=samples[finite])
        metrics_plot.getAxis("left").setStyle(showValues=True)
        metrics_plot.getAxis("bottom").setStyle(showValues=True)
        metrics_plot.setLabel("left", f"{label} ({unit})")
        metrics_plot.setLabel("bottom", "Sample")

    def set_horizontal_movable_line(self, label):
        """Restore the horizontal marker's y offset previously recorded for ``label``."""
        self._horizontal_movable_line.setY(self._horizontal_movable_line_position[label])

    def set_vertical_movable_line(self, label):
        """Restore the vertical marker's y offset previously recorded for ``label``."""
        self._vertical_movable_line.setY(self._vertical_movable_line_position[label])

    def reset(self):
        """Hide both marker lines (recorded offsets and reference values are kept)."""
        self._reset(self._horizontal_movable_line)
        self._reset(self._vertical_movable_line)

    @staticmethod
    def _reset(movable_line):
        """Hide ``movable_line``."""
        movable_line.setVisible(False)

    def init(self):
        """Zero both marker offsets and forget all recorded offsets and reference values."""
        self._init(self._horizontal_movable_line, self._horizontal_movable_line_position, self._horizontal_value)
        self._init(self._vertical_movable_line, self._vertical_movable_line_position, self._vertical_value)

    @staticmethod
    def _init(movable_line, movable_line_position, value):
        """Move ``movable_line`` back to offset 0 and reset the per-label bookkeeping."""
        movable_line.setY(0)
        for key in movable_line_position:
            movable_line_position[key] = 0
        for key in value:
            value[key] = None

    def set_movable_line_visible(self, state):
        """Enable or disable the marker lines.

        The flag is remembered for future metric updates. A line's visibility changes
        immediately only if its metrics plot currently holds samples.
        """
        self._is_movable_line_visible = state
        if self._has_samples(self._horizontal_metrics):
            self._horizontal_movable_line.setVisible(state)
        if self._has_samples(self._vertical_metrics):
            self._vertical_movable_line.setVisible(state)

    @staticmethod
    def _has_samples(metrics):
        """Return True if the ``metrics`` data item currently holds at least one plotted point."""
        x = metrics.getData()[0]
        return x is not None and len(x) > 0

    def clear(self):
        """Remove all plotted data, hide the marker lines and restore the bare plot design."""
        self._horizontal_metrics.clear()
        self._horizontal_projection.clear()
        self._horizontal_fit.clear()
        self._vertical_metrics.clear()
        self._vertical_projection.clear()
        self._vertical_fit.clear()
        self.reset()
        self._initialize_design(self._horizontal_metrics_plot)
        self._initialize_design(self._horizontal_projection_plot)
        self._initialize_design(self._vertical_metrics_plot)
        self._initialize_design(self._vertical_projection_plot)

    def get_horizontal_label(self, label):
        """Return the level the horizontal marker indicates for ``label``.

        The result is the reference value plus the line's current y offset, or None if
        no reference value has been recorded for ``label`` yet.
        """
        if self._horizontal_value[label] is None:
            return None
        return self._horizontal_value[label] + self._horizontal_movable_line.pos().y()

    def get_vertical_label(self, label):
        """Return the level the vertical marker indicates for ``label``.

        The result is the reference value plus the line's current y offset, or None if
        no reference value has been recorded for ``label`` yet.
        """
        if self._vertical_value[label] is None:
            return None
        return self._vertical_value[label] + self._vertical_movable_line.pos().y()

    def set_position(self, horizontal, vertical):
        """Move the horizontal and vertical marker lines to the given plot y levels."""
        self._horizontal_movable_line.set_position(horizontal)
        self._vertical_movable_line.set_position(vertical)

    def set_labels(self, horizontal_label, vertical_label):
        """Record the markers' current y offsets under the given horizontal/vertical labels."""
        self._horizontal_movable_line_position[horizontal_label] = self._horizontal_movable_line.pos().y()
        self._vertical_movable_line_position[vertical_label] = self._vertical_movable_line.pos().y()
