from numpy.polynomial.polynomial import Polynomial
from scipy.constants import c, e, h


class CalibrationService:
    """Convert raw detector projections into wavelength (nm) and photon-energy (eV) axes.

    Two calibration models are supported and selected via :meth:`use_camera`:

    * ``SPECTROMETER_FL22`` -- a polynomial with coefficients ``a_0`` .. ``a_5`` evaluated at
      the encoder readout plus the pixel offset from the fit center, converted with
      ``millimeter_per_pixel``.
    * ``GOTTHARD`` / ``GOTTHARD2`` -- a linear map through the two calibration points
      ``(x_1, lambda_1)`` and ``(x_2, lambda_2)``.

    The setters only edit ``calibration_data``. Call :meth:`update_polynomial` or
    :meth:`update_calibration` (or :meth:`use_camera`) to rebuild the model that is actually
    used for conversion.
    """

    SPECTROMETER_FL22 = "Spectrometer@FL22"
    GOTTHARD = "Gotthard"
    GOTTHARD2 = "Gotthard2"

    WAVELENGTH = "Wavelength"
    ENERGY = "Energy"

    _SPEED_OF_LIGHT = c  # m/s
    # Planck constant in eV*s: h (J*s) divided by the elementary charge (C).
    _PLANCK_CONSTANT = h / e

    # Number of polynomial coefficients a_0 .. a_5 stored in calibration_data.
    _NUMBER_OF_COEFFICIENTS = 6

    def __init__(self, calibration_data, encoder_data, image_data):
        """Store the data sources. No camera is selected until :meth:`use_camera` is called.

        :param calibration_data: object with ``polynomial_coefficients``, ``fit_center``,
            ``image`` and ``calibration_points`` mappings, plus ``use_camera()`` and ``save()``.
        :param encoder_data: object exposing a ``readout`` attribute.
        :param image_data: object holding the raw projection arrays that :meth:`update` reads
            and the calibrated arrays it overwrites in place.
        """
        self._calibration_data = calibration_data
        self._encoder_data = encoder_data
        self._image_data = image_data
        self._menu_items = (self.WAVELENGTH, self.ENERGY)
        self._camera = None
        self._polynomial = None
        self._calibration = None

    @staticmethod
    def _get_polynomial_with(coefficients):
        """Build a polynomial from the ``a_0`` .. ``a_5`` entries of *coefficients*.

        Coefficients are ordered lowest degree first. The values are copied, so later edits
        to *coefficients* do not affect the returned polynomial.
        """
        return Polynomial([coefficients["a_" + str(index)]
                           for index in range(CalibrationService._NUMBER_OF_COEFFICIENTS)])

    @staticmethod
    def _get_calibration_with(points):
        """Return a linear pixel -> wavelength map through ``(x_1, lambda_1)`` and ``(x_2, lambda_2)``.

        All four values are read once, here, so later edits to *points* (for example through
        ``set_calibration_x``) take effect only when the map is rebuilt. Without this, the
        slope would stay stale while the offset changed. ``x_1 == x_2`` makes the dispersion
        a division by zero.
        """
        lambda_1 = points["lambda_1"]
        x_1 = points["x_1"]
        dispersion = (points["lambda_2"] - lambda_1) / (points["x_2"] - x_1)
        return lambda x: lambda_1 + dispersion * (x - x_1)

    @property
    def menu_items(self):
        """Names of the selectable calibrated axes: ``(WAVELENGTH, ENERGY)``."""
        return self._menu_items

    @property
    def camera(self):
        """Currently selected camera name, or ``None`` before :meth:`use_camera`."""
        return self._camera

    def use_camera(self, camera):
        """Select *camera* on the calibration data and rebuild the matching model.

        Names other than the known cameras are stored, but no model is rebuilt for them.
        """
        self._calibration_data.use_camera(camera)
        self._camera = camera
        if self._camera == self.SPECTROMETER_FL22:
            self.update_polynomial()
        elif self._camera in (self.GOTTHARD, self.GOTTHARD2):
            self.update_calibration()

    def get_encoder_readout(self):
        """Return the current encoder readout."""
        return self._encoder_data.readout

    def get_coefficient(self, index):
        """Return polynomial coefficient ``a_<index>`` as a string."""
        return str(self._calibration_data.polynomial_coefficients["a_" + str(index)])

    def set_coefficient(self, index, coefficient):
        """Store coefficient ``a_<index>``. Takes effect after :meth:`update_polynomial`."""
        self._calibration_data.polynomial_coefficients["a_" + str(index)] = coefficient

    def get_fit_center(self, index):
        """Return the fit-center ``y`` coordinate if *index* is truthy, else ``x``, as a string."""
        return str(self._calibration_data.fit_center["y" if index else "x"])

    def set_fit_center(self, index, fit_center):
        """Store the fit-center ``y`` coordinate if *index* is truthy, else ``x``."""
        self._calibration_data.fit_center["y" if index else "x"] = fit_center

    def get_image_mm_per_px(self):
        """Return the image scale (millimetres per pixel) as a string."""
        return str(self._calibration_data.image["millimeter_per_pixel"])

    def set_image_mm_per_px(self, image_mm_per_px):
        """Store the image scale (millimetres per pixel)."""
        self._calibration_data.image["millimeter_per_pixel"] = image_mm_per_px

    def get_calibration_lambda(self, index):
        """Return calibration wavelength ``lambda_<index>`` as a string."""
        return str(self._calibration_data.calibration_points["lambda_" + str(index)])

    def set_calibration_lambda(self, index, _lambda):
        """Store ``lambda_<index>``. Takes effect after :meth:`update_calibration`."""
        self._calibration_data.calibration_points["lambda_" + str(index)] = _lambda

    def get_calibration_x(self, index):
        """Return calibration pixel position ``x_<index>`` as a string."""
        return str(self._calibration_data.calibration_points["x_" + str(index)])

    def set_calibration_x(self, index, x):
        """Store ``x_<index>``. Takes effect after :meth:`update_calibration`."""
        self._calibration_data.calibration_points["x_" + str(index)] = x

    def update_polynomial(self):
        """Rebuild the spectrometer polynomial from the stored coefficients."""
        self._polynomial = self._get_polynomial_with(self._calibration_data.polynomial_coefficients)

    def update_calibration(self):
        """Rebuild the linear Gotthard calibration from the stored calibration points."""
        self._calibration = self._get_calibration_with(self._calibration_data.calibration_points)

    def save_calibration(self):
        """Persist the calibration data through ``calibration_data.save()``."""
        self._calibration_data.save()

    def update(self):
        """Recompute all calibrated projections in ``image_data`` in place.

        Per-pulse train spectra are refreshed only when a Gotthard camera is selected.
        """
        self._update_wavelength_calibration()
        self._update_energy_calibration()
        if self._camera in (self.GOTTHARD, self.GOTTHARD2):
            self._update_train_energy_calibration()

    @staticmethod
    def _write_calibrated(target, source, calibrate):
        """Overwrite *target*'s horizontal x/y arrays with *calibrate* applied to *source*'s."""
        target.horizontal.x[...], target.horizontal.y[...] = calibrate(source.horizontal.x,
                                                                       source.horizontal.y)

    def _update_wavelength_calibration(self):
        """Refresh the nm-calibrated projections (accumulated and single-shot)."""
        image = self._image_data
        self._write_calibrated(image.projections_values_calibrated_in_nm,
                               image.projections_values, self.get_wavelength_calibration)
        self._write_calibrated(image.projections_values_calibrated_in_nm_single,
                               image.projections_values_single, self.get_wavelength_calibration)

    def _update_energy_calibration(self):
        """Refresh the eV-calibrated projections (accumulated and single-shot)."""
        image = self._image_data
        self._write_calibrated(image.projections_values_calibrated_in_ev,
                               image.projections_values, self.get_energy_calibration)
        self._write_calibrated(image.projections_values_calibrated_in_ev_single,
                               image.projections_values_single, self.get_energy_calibration)

    @staticmethod
    def get_pixel_calibration(x, y):
        """Identity calibration: return *x* and *y* unchanged (raw pixel axis)."""
        return x, y

    def _get_length_calibration(self, x):
        """Convert pixel positions to millimetres relative to the fit-center x coordinate."""
        return (x - self._calibration_data.fit_center["x"]) * self._calibration_data.image["millimeter_per_pixel"]

    def get_wavelength_calibration(self, x, y):
        """Map pixel positions *x* to wavelengths in nm. *y* is passed through unchanged.

        :returns: ``(wavelength, y)``, or ``None`` when no supported camera is selected.
        """
        if self._camera == self.SPECTROMETER_FL22:
            # NOTE(review): assumes the encoder readout is in the same length unit (mm) as
            # _get_length_calibration -- confirm against the encoder source.
            return self._polynomial(self._encoder_data.readout + self._get_length_calibration(x)), y
        if self._camera in (self.GOTTHARD, self.GOTTHARD2):
            return self._calibration(x), y
        # NOTE(review): with no supported camera this returns None, so
        # get_energy_calibration and update() then fail with a TypeError.
        return None

    def get_energy_calibration(self, x, y):
        """Map pixel positions to photon energies in eV and rescale *y* accordingly.

        The energy is E = h*c / lambda, with lambda in nm. *y* is multiplied by the Jacobian
        |d lambda / d E| = h*c / E**2, converted to nm/eV by the 1e-9 factor. This presumably
        turns a per-nm spectral density into a per-eV one. All operations are element-wise,
        so *y* may have leading axes that broadcast against the energy axis.

        :returns: ``(energy, intensity)``.
        """
        wavelength = self.get_wavelength_calibration(x, y)[0]
        energy = 1 / (wavelength * 1e-9) * self._SPEED_OF_LIGHT * self._PLANCK_CONSTANT
        intensity = y * self._SPEED_OF_LIGHT * self._PLANCK_CONSTANT / energy ** 2 / 1e-9
        return energy, intensity

    def _update_train_energy_calibration(self):
        """Energy-calibrate every pulse (row) of the train projections in place.

        The energy axis depends only on x, so it is computed once. The intensity rescaling is
        then broadcast over all pulses, instead of re-evaluating the wavelength model for
        each pulse. A train with zero pulses only refreshes the x axis.
        """
        train = self._image_data.projections_values_train
        calibrated = self._image_data.projections_values_train_calibrated_in_ev
        energy, intensity = self.get_energy_calibration(train.x, train.y)
        calibrated.x[...] = energy
        # Write one output row per input pulse. Any extra rows in the target stay untouched.
        calibrated.y[:intensity.shape[0], ...] = intensity
