import warnings

import numpy as np
from scipy.optimize import curve_fit, OptimizeWarning


class FitService:
    """Fit a Gaussian on a linear baseline to the projections of ``image_data``.

    For every projection set read in :meth:`update`, the horizontal and the
    vertical profile are each fitted with :meth:`_gauss_model`, and the results
    are handed to the corresponding ``fit_data.update*`` method.
    """

    # NOTE(review): these filters are installed once, when the class body is
    # executed, and apply process-wide -- not only to the fits below.
    # Presumably meant to silence numerical noise from degenerate or failed
    # fits; confirm before narrowing them to a local warnings.catch_warnings().
    warnings.filterwarnings("ignore", category=RuntimeWarning)
    warnings.filterwarnings("ignore", category=OptimizeWarning)

    def __init__(self, image_data, fit_data):
        """Store the projection source (``image_data``) and the result sink (``fit_data``)."""
        self._image_data = image_data
        self._fit_data = fit_data

    @property
    def fit_data(self):
        """The object that receives the fit results."""
        return self._fit_data

    def update(self):
        """Refit all projections of ``image_data`` and push the results to ``fit_data``.

        The three ``*_single`` projection sets are always refitted. The averaged
        sets are refitted only when ``image_data.to_average`` is truthy;
        otherwise ``fit_data.update_from_single()`` is called instead.
        """
        image = self._image_data
        fit_data = self._fit_data

        fit_data.update_single(*self._fit_pair(image.projections_values_single))
        fit_data.update_calibrated_in_nm_single(
            *self._fit_pair(image.projections_values_calibrated_in_nm_single))
        fit_data.update_calibrated_in_ev_single(
            *self._fit_pair(image.projections_values_calibrated_in_ev_single))

        if image.to_average:
            fit_data.update(*self._fit_pair(image.projections_values))
            fit_data.update_calibrated_in_nm(
                *self._fit_pair(image.projections_values_calibrated_in_nm))
            fit_data.update_calibrated_in_ev(
                *self._fit_pair(image.projections_values_calibrated_in_ev))
        else:
            fit_data.update_from_single()

    def _fit_pair(self, projections):
        """Fit the horizontal, then the vertical profile of ``projections``.

        :return: ``(h_fit, h_parameters, v_fit, v_parameters)`` -- the four
            positional arguments passed to the ``fit_data.update*`` methods.
        """
        return (*self._get_gaussian_fit(projections.horizontal),
                *self._get_gaussian_fit(projections.vertical))

    def _get_gaussian_fit(self, projection):
        """Fit one projection profile and sample the fitted curve densely.

        :param projection: object with array-like ``x`` and ``y`` attributes.
        :return: ``((x_fine, y_fine), parameters)``, where ``x_fine`` spans the
            finite x range of the data with ten times as many points as the data;
            ``(None, None)`` if the fit failed.
        """
        fit_parameters = self._get_gaussian_fit_parameters(projection.x, projection.y)
        if fit_parameters is None:
            return None, None
        x = np.asarray(projection.x, dtype=float)
        # Use finite end points: a NaN end point would make the whole grid NaN.
        finite_x = x[np.isfinite(x)]
        x_fine = np.linspace(finite_x[0], finite_x[-1], num=10 * len(x))
        return (x_fine, self._gauss_model(x_fine, *fit_parameters)), fit_parameters

    def _get_gaussian_fit_parameters(self, x, y):
        """Least-squares fit of :meth:`_gauss_model` to the samples ``(x, y)``.

        Samples where ``x`` or ``y`` is NaN are ignored, both for the start
        guess and for the fit. The first and last remaining samples are
        assumed to lie on the baseline when guessing start parameters.

        :return: array ``[offset, slope, amplitude, center, sigma]`` with
            ``sigma >= 0``, or ``None`` if there is too little data, the x range
            is degenerate, or the fit fails or yields non-finite values.
        """
        try:
            x = np.asarray(x, dtype=float)
            y = np.asarray(y, dtype=float)
            valid = ~(np.isnan(x) | np.isnan(y))
            x, y = x[valid], y[valid]

            span = x[-1] - x[0]  # IndexError (caught) when no valid samples remain
            if span == 0 or not np.isfinite(span):
                return None
            slope = (y[-1] - y[0]) / span
            # The model's baseline is ``offset + slope * x`` in absolute x, so
            # the offset guess is the intercept at x = 0, not y[0].
            offset = y[0] - slope * x[0]
            # Remove the baseline guess before locating the peak.
            y_corrected = y - (offset + slope * x)
            amplitude = np.max(y_corrected)
            position = x[np.argmax(y_corrected)]
            sigma = np.abs(span / 50)  # sigma guess is found empirically and should be used with caution
            guess = (offset, slope, amplitude, position, sigma)
            # TypeError is raised by curve_fit when there are fewer points than parameters.
            fit_parameters, *_ = curve_fit(self._gauss_model, x, y, p0=guess)
        except (RuntimeError, ValueError, TypeError, IndexError):
            return None
        if not np.all(np.isfinite(fit_parameters)):
            return None
        fit_parameters[4] = np.abs(fit_parameters[4])  # the model is symmetric in sigma
        return fit_parameters

    @staticmethod
    def _gauss_model(x, offset, slope, amplitude, center, sigma):
        """Gaussian peak of height ``amplitude`` at ``center`` on the line ``offset + slope * x``."""
        return offset + slope * x + amplitude * np.exp(-(x - center) ** 2 / 2 / sigma ** 2)

    def get_baseline_fit(self, x, y):
        """Fit ``(x, y)`` and return only the linear baseline of the fitted model.

        :return: ``(x, offset + slope * x)`` evaluated on the given ``x``, or
            ``None`` if the fit failed.
        """
        fit_parameters = self._get_gaussian_fit_parameters(x, y)
        if fit_parameters is None:
            return None
        offset, slope = fit_parameters[0], fit_parameters[1]
        # Must match the baseline term of _gauss_model (absolute x); evaluating
        # relative to x[0] would shift the baseline by slope * x[0].
        return x, offset + slope * np.asarray(x, dtype=float)
