import copy
from collections import namedtuple

import numpy as np


class ImageData:
    """Camera image state plus derived horizontal/vertical projections.

    Each call to :meth:`update` with a readout carrying a new ``event``
    converts the frame to float32 and, depending on the current mode, does one
    of the following:

    * feeds it into a running background average (background taking);
    * subtracts the stored background (background correction).

    It then recomputes the projections:

    * horizontal projection: ``np.nansum(image, axis=0)``, x offset ``offset[0]``
    * vertical projection:   ``np.nansum(image, axis=1)``, x offset ``offset[1]``

    Subclasses must implement :meth:`get_relevant_image_and_offset`.

    The ``*_calibrated_in_nm`` / ``*_calibrated_in_ev`` values are produced
    here as deep copies of the uncalibrated projections, each with its own unit
    labels. The calibration itself is presumably applied elsewhere; confirm
    with callers.
    """

    _Projections = namedtuple("Projections", ("horizontal", "vertical"))
    _Values = namedtuple("Values", ("x", "y"))
    _Units = namedtuple("Units", ("horizontal", "vertical"))

    class NoCameraImageType(Exception):
        def __init__(self):
            super().__init__("No camera image type")

    class ImageSizeChangedWhileTakingBackground(Exception):
        def __init__(self):
            super().__init__("Image size changed while taking background")

    class ImageSizeChangedWhileRunning(Exception):
        def __init__(self):
            super().__init__("Image size changed while running")

    class _ImageSizeChanged(Exception):  # marker exception
        pass

    @classmethod
    def get_relevant_image_and_offset(cls, image):
        """Return ``(image, offset)`` for the region used for projections.

        ``offset[0]`` is added to the horizontal x axis and ``offset[1]`` to
        the vertical x axis. Must be overridden by subclasses.
        """
        raise NotImplementedError

    def __init__(self, cameras):
        """Create an empty state.

        :param cameras: mapping of camera name to a dict that provides at least
            the ``"address"`` and ``"encoder"`` keys (read by :meth:`use_camera`).
        """
        self._cameras = cameras
        self._camera = None
        self._address = None
        self._use_encoder = None
        self._image = None
        self._projections_values_single = None
        self._projections_values = None
        self._projections_units = self._Units("px", "px")
        self._projections_values_calibrated_in_nm_single = None
        self._projections_values_calibrated_in_nm = None
        self._projections_units_calibrated_in_nm = self._Units("nm", "px")
        self._projections_values_calibrated_in_ev_single = None
        self._projections_values_calibrated_in_ev = None
        self._projections_units_calibrated_in_ev = self._Units("eV", "px")
        self._projections_values_train = None
        self._projections_values_train_calibrated_in_ev = None
        self._to_take_background = False
        self._to_correct_background = False
        self._to_average = False
        self._samples = None
        # Raw (unaveraged) projections of recent frames, oldest first.
        self._projections_values_buffer = []
        self._background_samples = None
        self._background_sample = None
        self._background_image = None
        self._corrected_image = None
        self._is_new_data_available = None
        self._readout = None
        self._event = None

    # --- read-only accessors -------------------------------------------------

    @property
    def cameras(self):
        return self._cameras.keys()

    @property
    def camera(self):
        return self._camera

    @property
    def address(self):
        return self._address

    @property
    def use_encoder(self):
        return self._use_encoder

    @property
    def image(self):
        return self._image

    @property
    def projections_values_single(self):
        return self._projections_values_single

    @property
    def projections_values(self):
        return self._projections_values

    @property
    def projections_units(self):
        return self._projections_units

    @property
    def projections_values_calibrated_in_nm_single(self):
        return self._projections_values_calibrated_in_nm_single

    @property
    def projections_values_calibrated_in_nm(self):
        return self._projections_values_calibrated_in_nm

    @property
    def projections_units_calibrated_in_nm(self):
        return self._projections_units_calibrated_in_nm

    @property
    def projections_values_calibrated_in_ev_single(self):
        return self._projections_values_calibrated_in_ev_single

    @property
    def projections_values_calibrated_in_ev(self):
        return self._projections_values_calibrated_in_ev

    @property
    def projections_units_calibrated_in_ev(self):
        return self._projections_units_calibrated_in_ev

    @property
    def projections_values_train(self):
        return self._projections_values_train

    @property
    def projections_values_train_calibrated_in_ev(self):
        return self._projections_values_train_calibrated_in_ev

    @property
    def to_take_background(self):
        return self._to_take_background

    @property
    def to_average(self):
        return self._to_average

    @property
    def is_new_data_available(self):
        return self._is_new_data_available

    @property
    def samples(self):
        return self._samples

    # --- configuration -------------------------------------------------------

    def use_camera(self, checked_camera):
        """Select ``checked_camera`` and load its address/encoder settings."""
        self._camera = checked_camera
        self._address = self._cameras[checked_camera]["address"]
        self._use_encoder = self._cameras[checked_camera]["encoder"]
        self._is_new_data_available = False

    def initialize(self, background_samples):
        """Start background taking, targeting ``background_samples`` frames."""
        self._to_take_background = True
        self._background_samples = background_samples
        self._background_sample = 0

    def set_background_correction(self, state):
        self._to_correct_background = state

    def set_averaging(self, state):
        self._to_average = state

    def set_samples(self, samples):
        """Set the averaging window length and drop previously buffered frames."""
        self._samples = samples
        self._projections_values_buffer.clear()

    def reset(self):
        """Stop background taking. The accumulated background image is kept."""
        self._to_take_background = False

    # --- processing ----------------------------------------------------------

    def update(self, image_readout):
        """Process a camera readout.

        ``image_readout`` must expose ``type``, ``event`` and ``data``
        attributes. Data is processed only when ``event`` differs from the last
        successfully processed one; ``is_new_data_available`` reflects that.
        If processing raises, ``_event`` is not advanced, so the same frame is
        processed again on the next call.

        :raises NoCameraImageType: if ``image_readout.type != "IMAGE"``.
        """
        if image_readout.type != "IMAGE":
            raise self.NoCameraImageType()
        self._readout = image_readout
        self._is_new_data_available = False
        if self._readout.event != self._event:
            self._update_data()
            self._event = self._readout.event
            self._is_new_data_available = True

    def _update_data(self):
        self._image = self._readout.data.astype(np.float32)
        # Background taking has priority over background correction.
        if self._to_take_background:
            self._apply_background_taking()
        elif self._to_correct_background:
            self._apply_background_correction()
        self.update_projections()

    def _apply_background_taking(self):
        """Fold the current frame into the running background average.

        The displayed image is always the (possibly unchanged) background
        average, even when the frame is rejected for a size mismatch.
        """
        try:
            self._background_sample, self._background_image = self._averaging(
                self._background_sample, self._background_image, self._image)
        except self._ImageSizeChanged:
            raise self.ImageSizeChangedWhileTakingBackground() from None
        finally:
            self._image = self._background_image

    def _apply_background_correction(self):
        """Subtract the stored background from the current frame.

        If no background has been taken yet, the frame is left uncorrected.
        The original code raised ``AttributeError`` on ``None.shape`` here.

        :raises ImageSizeChangedWhileRunning: if the frame and background
            shapes differ. In that case the last corrected image (or the
            background, if none exists) is kept as the displayed image.
        """
        background = self._background_image
        if background is None:
            return
        if self._image.shape == background.shape:
            self._corrected_image = self._image - background
            self._image = self._corrected_image
        else:
            if self._corrected_image is None:
                self._image = background
            else:
                self._image = self._corrected_image
            raise self.ImageSizeChangedWhileRunning()

    @classmethod
    def _averaging(cls, samples, average, image):
        """Return ``(samples + 1, updated running mean)`` including ``image``.

        :raises _ImageSizeChanged: if ``image`` does not match ``average``'s shape.
        """
        samples += 1
        if samples == 1:
            average = image
        elif image.shape == average.shape:
            # Scaling by a Python scalar (rather than building larger
            # intermediates) avoids memory issues from implicit casting.
            inverse_samples = 1 / samples
            average = (1 - inverse_samples) * average + inverse_samples * image
        else:
            raise cls._ImageSizeChanged()
        return samples, average

    def update_projections(self):
        """Recompute single, averaged and (for Gotthard cameras) train values."""
        image, offset = self.get_relevant_image_and_offset(self._image)
        self._projections_values = self._Projections(self._calculate_projection(image, offset, axis=0),
                                                     self._calculate_projection(image, offset, axis=1))

        # Per-frame copies, taken before any averaging.
        self._projections_values_single = copy.deepcopy(self._projections_values)
        self._projections_values_calibrated_in_nm_single = copy.deepcopy(self._projections_values)
        self._projections_values_calibrated_in_ev_single = copy.deepcopy(self._projections_values)

        if self._to_average:
            self._average_projections()

        self._projections_values_calibrated_in_nm = copy.deepcopy(self._projections_values)
        self._projections_values_calibrated_in_ev = copy.deepcopy(self._projections_values)

        if self._camera in ("Gotthard", "Gotthard2"):
            # x spans the columns (horizontal axis), hence offset[0].
            self._projections_values_train = self._Values(np.arange(0, image.shape[1], dtype=np.float32) + offset[0],
                                                          image)
            self._projections_values_train_calibrated_in_ev = copy.deepcopy(self._projections_values_train)

    def _calculate_projection(self, image, offset, axis):
        """Sum ``image`` along ``axis`` ignoring NaNs; x positions start at ``offset[axis]``."""
        summed = np.nansum(image, axis=axis)
        positions = np.arange(0, len(summed), dtype=np.float32) + offset[axis]
        return self._Values(positions, summed)

    def set_empty(self):
        """Forget the current image and any buffered frames."""
        self._image = None
        self.clear_buffer()

    def clear_buffer(self):
        self._projections_values_buffer.clear()

    def is_not_empty(self):
        return self._image is not None

    def get_background_taking_progress(self):
        """Return background-taking progress as an integer percentage.

        Returns 0 before :meth:`initialize` has been called. Returns 100 when
        zero (or fewer) background samples were requested, instead of raising.
        """
        if self._background_samples is None or self._background_sample is None:
            return 0
        if self._background_samples <= 0:
            return 100
        return int(self._background_sample / self._background_samples * 100)

    def _average_projections(self):
        """Replace the current projections with the mean over recent frames.

        The raw projections of the current frame are buffered. The y values are
        averaged over the last ``samples`` buffered frames, while the x axes
        come from the current frame.
        """
        buffer = self._projections_values_buffer
        buffer.append(self._projections_values)
        if self._samples is not None:
            # Always keep at least the current frame. The old pop-before-append
            # raised IndexError for samples == 0 and only trimmed on exact
            # equality.
            excess = len(buffer) - max(self._samples, 1)
            if excess > 0:
                del buffer[:excess]
        # NOTE(review): while ``samples`` is unset (None) the buffer is never
        # trimmed, as in the original. Confirm callers always call set_samples().

        horizontal_projection = self._Values(
            self._projections_values.horizontal.x,
            np.array([p.horizontal.y for p in buffer]).mean(axis=0))

        vertical_projection = self._Values(
            self._projections_values.vertical.x,
            np.array([p.vertical.y for p in buffer]).mean(axis=0))

        self._projections_values = self._Projections(horizontal_projection, vertical_projection)
