import warnings

import numba as nb
import numpy as np
from scipy.constants import e, hbar
from scipy.optimize import curve_fit, OptimizeWarning


class CorrelationService:
    """Compute a spectral intensity correlation (g2) and fit pulse-duration models to it.

    The service reads spectra from ``correlation_data`` (``energy`` and
    ``intensities``), correlates them around the intensity-weighted mean
    photon energy and hands the result, together with a Gaussian and a
    flat-top model fit, back to ``correlation_data.set``.
    """

    # NOTE: these filters are installed process-wide as soon as the class body is
    # executed (i.e. on import), not only for code running inside this class.
    warnings.filterwarnings("ignore", category=RuntimeWarning)
    warnings.filterwarnings("ignore", category=OptimizeWarning)

    # Reduced Planck constant divided by the elementary charge, i.e. hbar in eV*s.
    _PLANCK_BAR_CONSTANT = hbar / e

    def __init__(self, image_data, correlation_data):
        """Store the data sources; no computation happens until ``process`` is called."""
        self._image_data = image_data
        self._correlation_data = correlation_data
        self._correlation = None  # g2 values, set by _correlate_for / _apply_binning_with
        self._delta_omega = None  # lag axis belonging to _correlation, set by _apply_binning_with

    @property
    def correlation_data(self):
        """The correlation data container passed to the constructor."""
        return self._correlation_data

    def update(self):
        """Forward the current projections to the correlation data, if it is taking data.

        The spectrometer camera delivers a single horizontal projection, the
        Gotthard detectors deliver a whole train. Any other camera is ignored.
        """
        if not self._correlation_data.to_take_data:
            return

        camera = self._image_data.camera
        if camera == "Spectrometer@FL22":
            self._correlation_data.update(self._image_data.projections_values_calibrated_in_ev.horizontal,
                                          self._image_data.projections_values.horizontal)
        elif camera in ("Gotthard", "Gotthard2"):
            self._correlation_data.update_train(self._image_data.projections_values_train_calibrated_in_ev,
                                                self._image_data.projections_values_train)

    def process(self, energy_range, energy_bin):
        """Correlate, bin and fit the collected spectra and publish the result.

        Args:
            energy_range: Width of the correlation window, in the units of
                ``correlation_data.energy``.
            energy_bin: Width over which neighbouring central energies are
                averaged, in the same units.
        """
        self._correlate_for(energy_range)
        self._apply_binning_with(energy_bin)
        self._correlation_data.set(self._PLANCK_BAR_CONSTANT * self._delta_omega, g2_data=self._correlation,
                                   gauss_model_data=self._fit_g2_using(self._g2_gauss_model),
                                   flattop_model_data=self._fit_g2_using(self._g2_flattop_model))

    def _correlate_for(self, energy_range):
        """Compute the correlation map for a window of ``energy_range`` and store it.

        The window is converted to a number of grid points using the mean
        spacing of the energy axis. Non-finite results are replaced by 0.
        """
        energy = self._correlation_data.energy
        energy_step = (np.max(energy) - np.min(energy)) / (len(energy) - 1)

        correlation_range = int(energy_range / energy_step)
        if correlation_range > len(energy):
            correlation_range = len(energy)
        elif correlation_range < 3:  # the fit later needs at least 3 data points
            correlation_range = 3

        correlation = self._correlate_center(correlation_range)
        correlation[~np.isfinite(correlation)] = 0  # drop NaN and +/-inf alike
        self._correlation = correlation

    def _correlate_center(self, correlation_range):
        """Prepare zero-padded input and run the compiled correlation kernel.

        Returns:
            Array of shape ``(number_of_values, correlation_range)``; row ``i``
            is the correlation centred on grid point ``i``, column ``j`` the
            lag of ``j`` grid points.
        """
        intensities = np.array(self._correlation_data.intensities)
        samples, number_of_values = intensities.shape

        # Pad both ends so that the kernel can index symmetrically around every
        # grid point without leaving the array.
        padding = np.zeros((correlation_range, samples))
        values = np.r_[padding, intensities.T, padding]

        correlation = np.zeros((number_of_values, correlation_range))
        return self._correlate_center_numba(correlation, correlation_range, values)

    @staticmethod
    @nb.jit("double[:,:](double[:,:], int32, double[:,:])", nopython=True, nogil=True)
    def _correlate_center_numba(correlation, correlation_range, values):
        """Fill ``correlation[i, j]`` with <I_l * I_r> / (<I_l> * <I_r>).

        ``values`` has one row per (padded) grid point and one column per
        sample. For centre ``i`` and lag ``j`` the two rows ``l`` and ``r`` are
        exactly ``j`` rows apart and placed as symmetrically around ``i`` as
        possible. If either mean is zero the entry is set to 1.
        """
        samples = values.shape[1]
        number_of_values = values.shape[0] - correlation_range * 2
        for i in range(number_of_values):
            for j in range(correlation_range):
                # The left point sits floor(j / 2) rows below the centre, the
                # right point j rows above the left one; for odd j the extra
                # row therefore goes to the right-hand side.
                index_l = i - j // 2 + correlation_range
                index_r = index_l + j

                mean_s = 0.0
                mean_l = 0.0
                mean_r = 0.0
                for k in range(samples):
                    mean_s += values[index_l, k] * values[index_r, k]
                    mean_l += values[index_l, k]
                    mean_r += values[index_r, k]

                mean_s /= samples
                mean_l /= samples
                mean_r /= samples

                if mean_l == 0 or mean_r == 0:
                    correlation[i, j] = 1
                else:
                    correlation[i, j] = mean_s / mean_l / mean_r
        return correlation

    def _apply_binning_with(self, energy_bin):
        """Average the correlation map over bins of central energy and pick one row.

        After binning, the row whose energy is closest to the intensity-weighted
        mean photon energy is kept as the final 1-D correlation, and the
        matching lag axis is stored in ``_delta_omega``.
        """
        energy = self._correlation_data.energy
        freq_bin = int(energy_bin / abs(energy[1] - energy[0]))
        omega = energy / self._PLANCK_BAR_CONSTANT
        intensities = np.array(self._correlation_data.intensities).T

        number_of_values = omega.shape[0]
        freq_bin = min(freq_bin, number_of_values)  # avoid "over-binning"

        omega_new = omega
        intensities_new = intensities
        if freq_bin > 1:
            # Only complete bins are used; each bin is represented by its first
            # grid point on the energy axis and by the mean of its correlation rows.
            usable = (number_of_values // freq_bin) * freq_bin
            omega_new = omega[:usable:freq_bin]
            intensities_new = intensities[:usable:freq_bin]

            rows = (self._correlation.shape[0] // freq_bin) * freq_bin
            corr_b = self._correlation[:rows]
            self._correlation = corr_b.reshape(rows // freq_bin, freq_bin, corr_b.shape[1]).mean(axis=1)

        photon_energy = omega_new * self._PLANCK_BAR_CONSTANT
        mean_intensity = intensities_new.mean(axis=1)
        phen = np.sum(photon_energy * mean_intensity) / np.sum(mean_intensity)
        idx = np.abs(photon_energy - phen).argmin()
        self._correlation = self._correlation[idx]

        # Column j of the correlation map is a lag of exactly j grid steps, so
        # the axis is 0, d, 2d, ..., (N - 1) * d. Binning acts on the central
        # energy only, hence the step is the one of the original grid.
        omega_step = np.abs(omega[1] - omega[0])
        self._delta_omega = omega_step * np.arange(len(self._correlation))

    def _fit_g2_using(self, function):
        """Fit ``function`` to the stored correlation.

        Returns:
            ``(fit, duration)`` where ``fit`` is the model evaluated on the lag
            axis and ``duration`` the absolute fitted duration scaled by 1e15
            (pulse duration in fs), or ``(None, None)`` if the fit failed.
        """
        fit_parameters = self._get_model_fit_parameters(self._delta_omega, self._correlation, function)
        if fit_parameters is None:
            return None, None

        fit = function(self._delta_omega, *fit_parameters)
        return fit, np.abs(fit_parameters[2] * 1e15)

    @staticmethod
    def _get_model_fit_parameters(x, y, function):
        """Run ``curve_fit`` with a data-driven initial guess.

        Returns:
            The optimal ``(offset, amplitude, duration_fwhm)`` or ``None`` if
            ``curve_fit`` raised ``RuntimeError`` or ``ValueError``.
        """
        try:
            # Initial guess: the tail of the curve as offset, the drop from the
            # first to the last point as amplitude, and a duration derived from
            # the lag in the middle of the axis.
            offset = y[-1]
            amplitude = y[0] - y[-1]
            duration_fwhm = 2 * np.pi / x[len(x) // 2]
            guess = (offset, amplitude, duration_fwhm)
            fit_parameters, *_ = curve_fit(function, x, y, p0=guess)
            return fit_parameters
        except (RuntimeError, ValueError):
            return None

    @staticmethod
    def _g2_gauss_model(delta_omega, offset, amplitude, duration_fwhm):
        """Gaussian g2 model: ``offset + amplitude * exp(-x**2)``.

        Here ``x = delta_omega * duration_fwhm / (2 * sqrt(2 * ln 2))``.
        """
        x = delta_omega * duration_fwhm / (2 * np.sqrt(2 * np.log(2)))
        return offset + amplitude * np.exp(-x ** 2)

    @staticmethod
    def _g2_flattop_model(delta_omega, offset, amplitude, duration_fwhm):
        """Flat-top g2 model: ``offset + amplitude * (sin(x) / x)**2``.

        Here ``x = delta_omega * duration_fwhm / 2``. ``np.sinc`` is the
        normalised sinc, hence the division by pi.
        """
        x = delta_omega * duration_fwhm / 2
        return offset + amplitude * np.sinc(x / np.pi) ** 2
