import numpy as np
from pyqtgraph import setConfigOptions, GraphicsLayoutWidget, PlotDataItem, LabelItem, ImageItem, ROI, QtCore, \
    HistogramLUTItem

from doocspie.util.colors import Colors


class ViewerPlotter:
    """Pyqtgraph plotter showing an image together with its horizontal and vertical projections.

    The layout is a ``GraphicsLayoutWidget`` with two rows:

    * top row: horizontal fit-parameter label, horizontal projection plot, vertical fit-parameter label;
    * bottom row: histogram/LUT, image plot (with an optional, initially hidden ROI), vertical projection plot.

    Each projection plot additionally carries a fit curve drawn in DESY cyan.
    """

    # Runs once when the class is defined and configures pyqtgraph globally: image arrays are
    # interpreted as (row, column), i.e. image.shape[0] is the height and image.shape[1] the width.
    setConfigOptions(imageAxisOrder="row-major")

    # Fit-curve color scaled to 0..255 (assumes Colors.DESY.CYAN holds 0..1 components - TODO confirm).
    _DESY_CYAN = np.array(Colors.DESY.CYAN) * 255

    # HTML fragments used to lay out the fit-parameter labels.
    _WHITESPACE = "&nbsp;"
    _NEWLINE = "<br>"

    # Fit parameters shown in the labels, in display order; each is read as an attribute of the
    # fit-parameter object handed to set_fit_parameters.
    _FIT_PARAMETER_NAMES = ("offset", "slope", "amplitude", "center", "fwhm")

    def __init__(self):
        self._widget = GraphicsLayoutWidget()

        # Top row (order of addItem/addPlot calls defines the grid layout).
        self._horizontal_parameters_label = LabelItem(justify="left")
        self._widget.addItem(self._horizontal_parameters_label)

        self._horizontal_projection = PlotDataItem()
        self._horizontal_projection_plot = self._widget.addPlot()
        self._horizontal_projection_plot.addItem(self._horizontal_projection)
        self._horizontal_fit = PlotDataItem()
        self._horizontal_projection_plot.addItem(self._horizontal_fit)

        self._vertical_parameters_label = LabelItem(justify="left")
        self._widget.addItem(self._vertical_parameters_label)

        # Bottom row.
        self._widget.nextRow()
        self._image = ImageItem()
        self._histogram = HistogramLUTItem()
        self._histogram.setImageItem(self._image)
        self._widget.addItem(self._histogram)
        self._image_plot = self._widget.addPlot()
        self._image_plot.addItem(self._image)

        self._roi = ROI((0, 0), scaleSnap=True, translateSnap=True)
        # Extra flag attached to the ROI object; it is not read in this class (presumably used by callers).
        self._roi.updated_once = False
        self._roi_handles = []
        self._image_plot.addItem(self._roi)

        self._vertical_projection = PlotDataItem()
        self._vertical_projection_plot = self._widget.addPlot()
        self._vertical_projection_plot.addItem(self._vertical_projection)
        self._vertical_fit = PlotDataItem()
        self._vertical_projection_plot.addItem(self._vertical_fit)

        self._initialize_plots()
        # Callable set by set_projections; None until then.
        self._range_calibration = None
        # True until the first image has been shown; controls automatic level adjustment.
        self._first_readout = True

    def _initialize_plots(self):
        """Apply the one-time styling and configuration to every element of the layout."""
        self._initialize_horizontal_projection_plot()
        self._initialize_fit_labels()
        self._initialize_image_plot()
        self._initialize_roi()
        self._initialize_vertical_projection_plot()

    def _initialize_horizontal_projection_plot(self):
        """Style the horizontal projection plot: no mouse interaction, fixed maximum height, cyan fit pen."""
        self._initialize_design(self._horizontal_projection_plot)
        self._horizontal_projection_plot.setMouseEnabled(False, False)
        self._horizontal_projection_plot.hideButtons()
        self._horizontal_projection_plot.setMaximumHeight(225)
        self._horizontal_fit.setData(pen={"color": self._DESY_CYAN, "width": 2})

    def _initialize_fit_labels(self):
        """Limit both fit-parameter labels to 200 x 200."""
        for label in (self._horizontal_parameters_label, self._vertical_parameters_label):
            label.setMaximumWidth(200)
            label.setMaximumHeight(200)

    def _initialize_image_plot(self):
        """Style the image plot and invert its y-axis so row 0 is drawn at the top."""
        self._initialize_design(self._image_plot)
        self._image_plot.invertY(True)

    def _initialize_roi(self):
        """Configure the ROI handles and hide the ROI until it is enabled explicitly."""
        self._roi.handleSize = 8
        self._add_handles()
        self._roi.setVisible(False)

    def _add_handles(self):
        """Add four edge scale handles (left, right, top, bottom), each scaling about the ROI center."""
        for position in ((0, 0.5), (1, 0.5), (0.5, 0), (0.5, 1)):
            self._roi_handles.append(self._roi.addScaleHandle(position, (0.5, 0.5)))

    def _initialize_vertical_projection_plot(self):
        """Style the vertical projection plot: no mouse interaction, fixed maximum width, inverted y, cyan fit pen."""
        self._initialize_design(self._vertical_projection_plot)
        self._vertical_projection_plot.setMouseEnabled(False, False)
        self._vertical_projection_plot.hideButtons()
        self._vertical_projection_plot.setMaximumWidth(200)
        # Match the image plot's inverted y-axis so the projection lines up with the image rows.
        self._vertical_projection_plot.invertY(True)
        self._vertical_fit.setData(pen={"color": self._DESY_CYAN, "width": 2})

    @staticmethod
    def _initialize_design(plot):
        """Give ``plot`` a framed look: no context menu, all four axes shown without tick values or ticks."""
        plot.setMenuEnabled(False)
        plot.showAxis("top", show=True)
        plot.showAxis("right", show=True)
        for position in "top", "right", "left", "bottom":
            plot.getAxis(position).setStyle(showValues=False, tickLength=0)

    @property
    def widget(self):
        """The ``GraphicsLayoutWidget`` holding all plots."""
        return self._widget

    @property
    def roi(self):
        """The region-of-interest item placed on the image plot."""
        return self._roi

    def lock_aspect_ratio(self, state):
        """Lock (``state`` truthy) or unlock the aspect ratio of the image view box."""
        self._image_plot.getViewBox().setAspectLocked(state)

    def is_requested_roi_in_bounds(self, x_position, y_position, width, height):
        """Return whether the span ``position +/- size // 2`` fits inside the current image on both axes.

        Returns ``False`` when no image is loaded (the image item then reports no width/height).
        """
        image_width = self._image.width()
        image_height = self._image.height()
        if image_width is None or image_height is None:
            return False
        if x_position - width // 2 < 0 or x_position + width // 2 > image_width:
            return False
        if y_position - height // 2 < 0 or y_position + height // 2 > image_height:
            return False
        return True

    def set_roi_position(self, x_position, y_position):
        """Move the ROI to ``(x_position, y_position)`` without emitting its update/finish signals.

        NOTE(review): presumably followed by set_roi_size, which emits the signals - confirm with callers.
        """
        self._roi.setPos(x_position, y=y_position, update=False, finish=False)

    def set_roi_size(self, width, height):
        """Resize the ROI and emit its update and finish signals."""
        self._roi.setSize((width, height), update=True, finish=True)

    def fit_roi_to_image(self):
        """Make the ROI cover the whole image."""
        self._roi.setSize((self._image.width(), self._image.height()))
        self._roi.setPos(self._image.pos())

    def get_relevant_image_and_offset(self, image):
        """Return the part of ``image`` covered by the ROI together with the ROI position.

        If the image size differs from the one currently displayed, the image is loaded first, the
        ROI bounds are updated and the ROI is reset to span the full image.

        Returns:
            ``(cropped_image, roi_position)`` where ``roi_position`` is ``self.roi.pos()``.
        """
        if self._is_image_size_changed(image):
            self._image.setImage(image)
            self._roi.maxBounds = QtCore.QRectF(0, 0, self._image.width(), self._image.height())
            self.fit_roi_to_image()

        pos = np.array(self._roi.pos(), dtype=int)
        size = np.array(self._roi.size(), dtype=int)
        # Row-major: first index is y (rows), second is x (columns).
        return image[pos[1]:pos[1] + size[1], pos[0]:pos[0] + size[0]], self._roi.pos()

    def _is_image_size_changed(self, image):
        """Return whether ``image`` (row-major: height, width) differs in size from the displayed image."""
        return self._image.height() != image.shape[0] or self._image.width() != image.shape[1]

    def add_to_connections_for_changes(self, action):
        """Connect ``action`` to both image-view range changes and ROI region changes."""
        self._image_plot.sigRangeChanged.connect(action)
        self._roi.sigRegionChanged.connect(action)

    def enable_region_of_interest(self, state):
        """Show or hide the ROI."""
        self._roi.setVisible(state)

    def set_image(self, image):
        """Display ``image``; levels are auto-adjusted only for the first image after construction/reset.

        When the image size changes, the ROI bounds are updated to the new image extent.
        """
        # Must be evaluated before setImage replaces the displayed image.
        size_changed = self._is_image_size_changed(image)
        self._image.setImage(image, autoLevels=self._first_readout)
        if size_changed:
            self._roi.maxBounds = QtCore.QRectF(0, 0, self._image.width(), self._image.height())

        if self._first_readout:
            self._histogram.setLevels(image.min(), image.max())
            self._first_readout = False

    def reset(self):
        """Make the next set_image call auto-adjust levels again."""
        self._first_readout = True

    def set_projections(self, values, units, range_calibration, invert_x_axis):
        """Plot the horizontal and vertical projections and label their axes.

        Args:
            values: object with ``horizontal`` and ``vertical`` entries, each exposing ``x`` and ``y``.
                The vertical projection is drawn with x and y swapped so it runs along the plot's y-axis.
            units: object with ``horizontal`` and ``vertical`` unit strings used in the axis labels.
            range_calibration: callable taking the image's visible x and y ranges (numpy arrays) and
                returning a sequence whose first element is the calibrated horizontal ``(min, max)``.
            invert_x_axis: whether to invert the horizontal projection's x-axis.
        """
        self._horizontal_projection.setData(x=values.horizontal.x, y=values.horizontal.y)
        self._horizontal_projection_plot.getAxis("bottom").setStyle(showValues=True)
        self._horizontal_projection_plot.setLabel("bottom", f"Horizontal ({units.horizontal})")

        self._vertical_projection.setData(x=values.vertical.y, y=values.vertical.x)
        self._vertical_projection_plot.getAxis("left").setStyle(showValues=True)
        self._vertical_projection_plot.setLabel("left", f"Vertical ({units.vertical})")

        self._horizontal_projection_plot.getViewBox().invertX(invert_x_axis)

        self._range_calibration = range_calibration
        self.update_axis()

    def set_fits(self, fits):
        """Draw the fit curves; a ``None`` entry in ``fits.horizontal``/``fits.vertical`` clears that curve."""
        if fits.horizontal is None:
            self._horizontal_fit.clear()
        else:
            self._horizontal_fit.setData(x=fits.horizontal.x, y=fits.horizontal.y)

        if fits.vertical is None:
            self._vertical_fit.clear()
        else:
            # Swapped like the vertical projection so the curve runs along the y-axis.
            self._vertical_fit.setData(x=fits.vertical.y, y=fits.vertical.x)

    def set_fit_parameters(self, fit_parameters, use_encoder, encoder_readout):
        """Fill both fit-parameter labels; the horizontal one also shows ``encoder_readout`` if ``use_encoder``."""
        if use_encoder:
            horizontal_text = self._get_parameters_label_text(fit_parameters.horizontal, "horizontal parameters",
                                                              spaces=33, encoder_readout=encoder_readout)
        else:
            horizontal_text = self._get_parameters_label_text(fit_parameters.horizontal, "horizontal parameters",
                                                              spaces=31)
        self._horizontal_parameters_label.setText(horizontal_text)

        self._vertical_parameters_label.setText(
            self._get_parameters_label_text(fit_parameters.vertical, "vertical parameters", spaces=28,
                                            vertical_offset=4, horizontal_offset=0))

    def _get_parameters_label_text(self, fit_parameters, label, spaces, encoder_readout=None,
                                   vertical_offset=0, horizontal_offset=0):
        """Build the HTML text of a fit-parameter label.

        Args:
            fit_parameters: object exposing the attributes listed in ``_FIT_PARAMETER_NAMES``.
            label: title line of the label.
            spaces: number of ``-`` characters in the separator under the title.
            encoder_readout: optional value shown above the title with three decimals.
            vertical_offset: number of leading line breaks.
            horizontal_offset: number of non-breaking spaces indenting title, separator and rows.
        """
        rows = []
        # `is not None` so that a legitimate readout of 0.0 is still displayed.
        if encoder_readout is not None:
            rows.append(f"encoder readout: {encoder_readout:.3f}" + 2 * self._NEWLINE)

        indent = horizontal_offset * self._WHITESPACE
        rows.append(vertical_offset * self._NEWLINE)
        rows.append(indent + label + self._NEWLINE)
        rows.append(indent + "-" * spaces + self._NEWLINE)
        rows.extend(self._get_parameters_label_row(name, getattr(fit_parameters, name), 1 + horizontal_offset)
                    for name in self._FIT_PARAMETER_NAMES)
        return "".join(rows)

    def _get_parameters_label_row(self, parameter_name, parameter, leading_blanks):
        """Return one ``name: value`` row (two decimals) indented by ``leading_blanks`` non-breaking spaces."""
        return leading_blanks * self._WHITESPACE + f"{parameter_name}: {parameter:.2f}" + self._NEWLINE

    def clear(self):
        """Remove all displayed data, fit curves, axis labels and fit-parameter texts."""
        self.reset()
        self._horizontal_projection.clear()
        self._horizontal_fit.clear()
        self._horizontal_projection_plot.setLabel("bottom", "")
        self._horizontal_projection_plot.getAxis("bottom").setStyle(showValues=False, tickLength=0)
        self._horizontal_parameters_label.setText("")
        self._image.clear()
        self._vertical_projection.clear()
        self._vertical_fit.clear()
        self._vertical_projection_plot.setLabel("left", "")
        self._vertical_projection_plot.getAxis("left").setStyle(showValues=False, tickLength=0)
        self._vertical_parameters_label.setText("")

    def activate_roi(self, state):
        """Enable (``state`` truthy) or disable moving and resizing of the ROI.

        Disabling removes the scale handles; enabling re-adds them if they were removed.
        """
        self._roi.translatable = state
        if state:
            if not self._roi_handles:
                self._add_handles()
        else:
            for handle in self._roi_handles:
                self._roi.removeHandle(handle)
            self._roi_handles.clear()

    def update_axis(self):
        """Synchronize the projection plots' ranges with the image view.

        With a visible ROI both projection plots auto-range. Otherwise the vertical projection follows
        the image's visible y-range, and the horizontal projection follows the visible x-range mapped
        through the range calibration given to set_projections. The horizontal mapping is skipped while
        no calibration has been set.
        """
        if self._roi.isVisible():
            self._horizontal_projection_plot.enableAutoRange()
            self._vertical_projection_plot.enableAutoRange()
            return

        x_range = self._image_plot.getAxis("bottom").range
        y_range = self._image_plot.getAxis("left").range
        # update_axis may be triggered (e.g. via sigRangeChanged) before set_projections supplied a
        # calibration; calling None would raise TypeError.
        if self._range_calibration is not None:
            calibrated_x_range = self._range_calibration(np.array(x_range), np.array(y_range))[0]
            self._horizontal_projection_plot.setXRange(*calibrated_x_range, padding=0)
        self._vertical_projection_plot.setYRange(*y_range, padding=0)
