import copy

import numpy as np


class CorrelationData:
    """Collect spectra for a correlation measurement and hold its analysis results.

    Spectra are recorded with :meth:`update` (one row per call) or
    :meth:`update_train` (several pulse rows per call) after
    :meth:`initialize`. :meth:`reset` stops collection and snapshots the
    collected intensity rows; :meth:`initialize_corrections` restores a
    fresh deep copy of that snapshot so corrections such as
    :meth:`subtract_baseline` can be (re)applied without touching it.
    """

    class ImageSizeChangedWhileTakingCorrelationData(Exception):
        """Raised when a new row's length differs from the rows already collected."""

        def __init__(self):
            super().__init__("Image size changed while taking correlation data")

    def __init__(self):
        self._energy = None  # ``x`` of the first spectrum_ev recorded
        self._intensities = []  # working list of recorded spectrum_ev ``y`` rows
        self._intensities_full = None  # snapshot of the rows, taken by reset()
        self._horizontal_axis = None  # ``x`` of the first spectrum_px recorded
        self._baselines = []  # recorded spectrum_px ``y`` rows
        self._sample = None  # number of accepted update/update_train calls
        self._samples = None  # expected number of calls, used for progress
        self._delta_energy = None
        self._g2_data = None
        self._g2_fit_gauss = None
        self._pulse_duration_gauss = None
        self._g2_fit_flattop = None
        self._pulse_duration_flattop = None
        self._to_take_data = False
        self._baseline_fit = None

    @property
    def energy(self):
        """Horizontal axis (``x``) of the first recorded spectrum_ev, or None."""
        return self._energy

    @property
    def intensities(self):
        """Working list of intensity rows (collected or corrected)."""
        return self._intensities

    @property
    def baseline(self):
        """Return ``(horizontal_axis, element-wise mean of the baseline rows)``."""
        return self._horizontal_axis, np.array(self._baselines).mean(axis=0)

    @property
    def delta_energy(self):
        return self._delta_energy

    @property
    def g2_data(self):
        return self._g2_data

    @property
    def g2_fit_gauss(self):
        return self._g2_fit_gauss

    @property
    def pulse_duration_gauss(self):
        return self._pulse_duration_gauss

    @property
    def g2_fit_flattop(self):
        return self._g2_fit_flattop

    @property
    def pulse_duration_flattop(self):
        return self._pulse_duration_flattop

    @property
    def to_take_data(self):
        """True between :meth:`initialize` and :meth:`reset`."""
        return self._to_take_data

    def initialize(self, samples):
        """Start a new acquisition, discarding previously collected rows.

        Args:
            samples: expected number of update/update_train calls; only used
                by :meth:`get_samples_taking_progress`.
        """
        self._to_take_data = True
        self._energy = None
        # Rebind rather than clear(): the old list may be None or shared.
        self._intensities = []
        self._intensities_full = None
        self._horizontal_axis = None
        self._baselines = []
        self._sample = 0
        self._samples = samples

    def reset(self):
        """Stop taking data and snapshot the collected intensity rows."""
        self._to_take_data = False
        # Copy the list so clearing/rebinding the working list later cannot
        # destroy the snapshot that initialize_corrections() restores from.
        self._intensities_full = list(self._intensities)

    def _check_row_sizes(self, stored, new_rows):
        """Raise if any row in *new_rows* differs in length from *stored* rows.

        The reference length is the last stored row, or the first new row
        when nothing is stored yet. Nothing is modified by this check.
        """
        expected = len(stored[-1]) if stored else None
        for row in new_rows:
            if expected is None:
                expected = len(row)
            elif len(row) != expected:
                raise self.ImageSizeChangedWhileTakingCorrelationData() from None

    def update(self, spectrum_ev, spectrum_px):
        """Record one spectrum.

        Args:
            spectrum_ev: object with ``x`` (axis) and ``y`` (one intensity row).
            spectrum_px: object with ``x`` (axis) and ``y`` (one baseline row).

        Raises:
            ImageSizeChangedWhileTakingCorrelationData: if a row's length
                differs from the rows already recorded; in that case nothing
                is recorded and the sample counter is not advanced.
        """
        # Validate first so a rejected spectrum leaves no partial state behind.
        self._check_row_sizes(self._intensities, [spectrum_ev.y])
        self._check_row_sizes(self._baselines, [spectrum_px.y])

        self._sample += 1

        if self._energy is None:
            self._energy = spectrum_ev.x
        self._intensities.append(spectrum_ev.y)

        if self._horizontal_axis is None:
            self._horizontal_axis = spectrum_px.x
        self._baselines.append(spectrum_px.y)

    def update_train(self, spectra_ev, spectra_px):
        """Record every pulse of one train; counts as a single sample.

        Args:
            spectra_ev: object with ``x`` (axis) and ``y`` (iterable of
                intensity rows, one per pulse).
            spectra_px: object with ``x`` (axis) and ``y`` (iterable of
                baseline rows, one per pulse).

        Raises:
            ImageSizeChangedWhileTakingCorrelationData: if any pulse row's
                length differs from the rows already recorded (or from the
                other pulses of the train); nothing is recorded in that case.
        """
        pulses_ev = list(spectra_ev.y)
        pulses_px = list(spectra_px.y)
        # Compare every new pulse against the existing data; comparing only the
        # last two rows after appending would compare pulses of the same train.
        self._check_row_sizes(self._intensities, pulses_ev)
        self._check_row_sizes(self._baselines, pulses_px)

        self._sample += 1

        if self._energy is None:
            self._energy = spectra_ev.x
        self._intensities.extend(pulses_ev)

        if self._horizontal_axis is None:
            self._horizontal_axis = spectra_px.x
        self._baselines.extend(pulses_px)

    def get_samples_taking_progress(self):
        """Return acquisition progress as an integer percentage.

        Returns 0 when no positive sample count has been set via initialize().
        """
        if not self._samples:
            return 0
        return int(self._sample / self._samples * 100)

    def set(self, delta_energy, g2_data, gauss_model_data, flattop_model_data):
        """Store analysis results.

        Args:
            delta_energy: stored as-is, exposed by :attr:`delta_energy`.
            g2_data: stored as-is, exposed by :attr:`g2_data`.
            gauss_model_data: pair ``(fit, pulse_duration)`` for the Gaussian model.
            flattop_model_data: pair ``(fit, pulse_duration)`` for the flat-top model.
        """
        self._delta_energy = delta_energy
        self._g2_data = g2_data
        self._g2_fit_gauss = gauss_model_data[0]
        self._pulse_duration_gauss = gauss_model_data[1]
        self._g2_fit_flattop = flattop_model_data[0]
        self._pulse_duration_flattop = flattop_model_data[1]

    def set_baseline_fit(self, baseline_fit):
        """Store the baseline fit; element ``[1]`` is what subtract_baseline() uses."""
        self._baseline_fit = baseline_fit

    def initialize_corrections(self):
        """Replace the working rows with a deep copy of the reset() snapshot.

        With no snapshot yet, the working list becomes empty.
        """
        if self._intensities_full is None:
            self._intensities = []
        else:
            self._intensities = copy.deepcopy(self._intensities_full)

    def subtract_baseline(self):
        """Subtract ``baseline_fit[1]`` from every working row, if a fit is set.

        New row objects are produced (no in-place ``-=``), so rows shared
        with the reset() snapshot are never modified.
        """
        if self._baseline_fit is None:
            return
        offset = self._baseline_fit[1]
        self._intensities = [row - offset for row in self._intensities]

    def set_empty(self):
        """Empty the working rows without touching the reset() snapshot."""
        self._intensities = []
