from concurrent.futures import ThreadPoolExecutor
import concurrent
import numpy as np
from scipy.interpolate import CubicSpline
from calculation_methods.py_calculations.calculation_constants.constants import CUBICMODEL
from calculation_methods.py_calculations.error_handler.error_messages import ErrorMessages
from calculation_methods.py_calculations import calculation_utils as utils

class CubicSplineInterpolator:
    """Natural cubic-spline curve fitted through per-identifier averages.

    Values in ``input_data``, ``standard_values`` and ``blank_values`` are
    grouped by ``identifier`` and averaged. The averages are pooled,
    de-duplicated on x, and a natural cubic spline (``bc_type='natural'``) is
    fitted through them. :meth:`predict_x` inverts the spline, mapping a y
    value back to x.
    """

    # Relative tolerance (scaled by segment width) used by predict_x to accept
    # a polynomial root as real and as lying inside its segment. It absorbs
    # np.roots round-off at knots and for (near-)double roots.
    _ROOT_TOL = 1e-9

    def __init__(self, input_data, standard_values, blank_values, weighted=False):
        """Average the datasets and fit the spline.

        Args:
            input_data, standard_values, blank_values: iterables of dicts read
                via the keys ``'identifier'``, ``'x'`` and ``'y'``. ``'x'`` and
                ``'y'`` are sequences of numbers averaged per identifier.
            weighted: stored on the instance; not used elsewhere in this class.

        Fewer than 3 unique points, or non-finite averages, are reported through
        ``utils.handle_error`` before the spline is built.
        """
        self.input_data = input_data
        self.standard_values = standard_values
        self.blank_values = blank_values
        self.weighted = weighted
        self.processed_data = self.process_cubic_spline_data()
        self.x_unique = self.processed_data['x_unique']
        self.y_unique = self.processed_data['y_unique']
        if len(self.x_unique) < 3 or len(self.y_unique) < 3:
            utils.handle_error(ErrorMessages.ERROR_INSUFFICIENT_DATA_POINTS)
        if not np.all(np.isfinite(self.x_unique)) or not np.all(np.isfinite(self.y_unique)):
            utils.handle_error(ErrorMessages.ERROR_INVALID_DATA)
        self.spline = CubicSpline(self.x_unique, self.y_unique, bc_type='natural')
        self.coefficients = self.extract_coefficients()

    @staticmethod
    def _average_by_identifier(data):
        """Return ``(x_means, y_means)`` arrays with one entry per usable identifier."""
        grouped = {}
        for entry in data:
            bucket = grouped.setdefault(entry.get('identifier'), {'x': [], 'y': []})
            xs = entry.get('x')
            ys = entry.get('y')
            # A missing or None 'x'/'y' contributes nothing instead of crashing extend().
            if xs is not None:
                bucket['x'].extend(xs)
            if ys is not None:
                bucket['y'].extend(ys)

        x_means = []
        y_means = []
        for values in grouped.values():
            xs, ys = values['x'], values['y']
            # NOTE(review): only the first value is checked for None; a later
            # None makes np.mean raise and the caller reports ERROR_PROCESS_VALUES.
            if xs and xs[0] is not None and ys and ys[0] is not None:
                x_means.append(np.mean(xs))
                y_means.append(np.mean(ys))
        return np.array(x_means), np.array(y_means)

    def process_cubic_spline_data(self):
        """Build the spline knots from the three datasets.

        Returns:
            dict with ``'x_unique'`` (ascending, de-duplicated x averages),
            ``'y_unique'`` (the y average paired with the first occurrence of
            each x, in input -> standard -> blank order) and the three raw
            datasets. On any exception, ``ERROR_PROCESS_VALUES`` is reported via
            ``utils.handle_error`` and the method falls through (returns None).
        """
        try:
            x_input_avg, y_input_avg = self._average_by_identifier(self.input_data)
            x_standard_avg, y_standard_avg = self._average_by_identifier(self.standard_values)
            x_blank_avg, y_blank_avg = self._average_by_identifier(self.blank_values)

            x_combined = np.concatenate([x_input_avg, x_standard_avg, x_blank_avg])
            y_combined = np.concatenate([y_input_avg, y_standard_avg, y_blank_avg])

            # np.unique already returns the values sorted ascending; return_index
            # yields each x's first occurrence so the matching y can be picked.
            x_unique, first_indices = np.unique(x_combined, return_index=True)
            y_unique = y_combined[first_indices]

            return {
                'x_unique': x_unique,
                'y_unique': y_unique,
                'input_data': self.input_data,
                'standard_values': self.standard_values,
                'blank_values': self.blank_values
            }
        except Exception:
            utils.handle_error(ErrorMessages.ERROR_PROCESS_VALUES)

    def extract_coefficients(self):
        """Return one coefficient record per spline segment.

        For segment ``i`` on ``[x_i, x_{i+1}]`` scipy stores the *local*
        polynomial ``S(x) = d*(x-x_i)**3 + c*(x-x_i)**2 + b*(x-x_i) + a`` in
        ``self.spline.c[:, i]`` (rows from highest power down).

        Each record holds ``identifier`` (the first standard whose ``'x'`` equals
        ``[x_end]``, else None), ``x_start``, ``x_end``, ``a``, ``b``, ``c``, ``d``.
        On any exception, ``ERROR_GET_COEFFICIENTS`` is reported via
        ``utils.handle_error``.
        """
        try:
            poly = self.spline.c
            table = []
            for i in range(len(self.x_unique) - 1):
                x_start = self.x_unique[i]
                x_end = self.x_unique[i + 1]
                identifier = next(
                    (entry['identifier'] for entry in self.standard_values if entry['x'] == [x_end]),
                    None,
                )
                table.append({
                    'identifier': identifier,
                    'x_start': x_start,
                    'x_end': x_end,
                    'a': poly[3, i],
                    'b': poly[2, i],
                    'c': poly[1, i],
                    'd': poly[0, i]
                })
            # Backward compatibility: the previous (threaded, racy) version left
            # segment coefficients on the instance; set them deterministically
            # to the last segment's values.
            if table:
                last = table[-1]
                self.a, self.b, self.c, self.d = last['a'], last['b'], last['c'], last['d']
            return table
        except Exception:
            utils.handle_error(ErrorMessages.ERROR_GET_COEFFICIENTS)

    def predict_x(self, y_val):
        """Invert the spline: return the first x in the knot range with S(x) == y_val.

        Segments are scanned left to right. Each local cubic
        ``d*t**3 + c*t**2 + b*t + (a - y_val)`` is solved for the offset
        ``t = x - x_start``; a real root with ``0 <= t <= x_end - x_start``
        (within ``_ROOT_TOL``) is mapped back to ``x = x_start + t``.

        Returns:
            The x value, or None when y_val is not attained between the first
            and last knot, or when an error occurred (reported as
            ``ERROR_PREDICT_X``).
        """
        try:
            for coeff in self.coefficients:
                x_start = coeff['x_start']
                width = coeff['x_end'] - x_start
                tol = self._ROOT_TOL * width
                poly = [coeff['d'], coeff['c'], coeff['b'], coeff['a'] - y_val]
                if not np.any(poly):
                    # The segment equals y_val everywhere; np.roots would return
                    # nothing, but its start point is a valid solution.
                    return x_start
                for root in np.roots(poly):
                    offset = root.real
                    if abs(root.imag) <= tol and -tol <= offset <= width + tol:
                        # Roots are local offsets, not absolute x; clamp
                        # round-off just outside the segment back onto it.
                        return x_start + min(max(offset, 0.0), width)
            return None
        except Exception:
            utils.handle_error(ErrorMessages.ERROR_PREDICT_X)
            return None

    def calculate_cv(self, sample_type):
        """Return the CV of x predictions for input entries whose identifier starts with sample_type.

        Every non-None y of a matching entry is inverted with :meth:`predict_x`;
        unresolved predictions (None) are dropped before
        ``utils.compute_statistical_cv`` is called.
        """
        concentrations = []
        for entry in self.input_data:
            identifier = entry.get('identifier')
            if identifier is None or not identifier.startswith(sample_type):
                continue
            ys = entry.get('y')
            if ys is None:
                continue
            for y in ys:
                if y is None:
                    continue  # nothing to invert; avoids a spurious predict_x error
                predicted = self.predict_x(y)
                if predicted is not None:
                    concentrations.append(predicted)
        return utils.compute_statistical_cv(concentrations)

    def get_coefficients(self):
        """Assemble the fit report for the cubic-spline model.

        Returns:
            dict with ``method`` (``'cubic_spline'``), ``coefficients`` (every
            segment record), ``metrics`` (``utils.calculate_statistics`` on the
            knots vs. the spline evaluated there), ``Additional_Table_Details``
            and ``Curve_Data_Points`` (both from ``utils`` helpers, which receive
            the full segment list as their parameters).
        """
        segments = self.coefficients
        metrics = utils.calculate_statistics(
            y_true=self.y_unique,
            y_pred=self.spline(self.x_unique),
            coefficients=segments,
        )
        additional_table_details = utils.generate_table_by_sample_groups(
            input_data=self.input_data,
            model_type=CUBICMODEL,
            params=segments,
            predict_func=self.predict_x,
            cv_func=self.calculate_cv,
        )
        curve_data_points = utils.generate_interpolated_curve_data(
            x_data=self.x_unique,
            y_data=self.y_unique,
            curve_type=CUBICMODEL,
            coefficients=segments,
        )
        return {
            "method": "cubic_spline",
            "coefficients": self.coefficients,
            "metrics": metrics,
            "Additional_Table_Details": additional_table_details,
            "Curve_Data_Points": curve_data_points
        }
