import concurrent.futures
from collections import defaultdict
import numpy as np
import logging
log = logging.getLogger('django')
from calculation_methods.py_calculations.calculation_constants.constants import FIVEPLMODEL, LINEARMODEL, LLD_CONSTANT
from calculation_methods.py_calculations.error_handler.error_messages import ErrorMessages
from calculation_methods.py_calculations import calculation_utils as utils

class LinearRegression:
    """Fit a straight line ``y = a + b * x`` and build the related report data.

    Three fitting modes are supported (selected in ``get_a_b``):

    * forced through the origin (``a = 0``) when ``force_to_zero`` is true
      (the default); weights are applied inside this fit when ``weighted`` is
      also true;
    * weighted least squares when only ``weighted`` is true;
    * ordinary least squares otherwise.

    Failures are reported through ``utils.handle_error``.
    NOTE(review): whether that helper raises or only records the error is not
    visible here; if it returns, the failing method returns ``None``.
    """

    def __init__(self, input_data, standard_values, blank_values, weighted=False, force_to_zero=True):
        """Store the configuration; no computation happens until ``execute``.

        Args:
            input_data: Raw entries. Each entry is read for the keys
                'identifier', 'coordinates', 'x' and 'y' when building table
                groups, and the whole list is passed to
                ``utils.process_input_data``.
            standard_values: Stored as-is; not read elsewhere in this class.
            blank_values: Stored as-is; not read elsewhere in this class (the
                LLD uses ``processed_data['blank_y']`` instead).
            weighted: Use weights returned by ``utils.validate_points``.
            force_to_zero: Force the intercept to 0. This takes priority over
                ``weighted`` when choosing the fitting method.
        """
        self.data = input_data
        self.standard_values = standard_values
        self.blank_values = blank_values
        self.weighted = weighted
        self.force_to_zero = force_to_zero
        self.lld_constant = LLD_CONSTANT

        # Everything below is populated by execute().
        self.processed_data = None
        self.x_values = None
        self.y_values = None
        self.weights = None
        self.a = None
        self.b = None
        self.coefficients = None

    def execute(self):
        """Run the full pipeline: validate, fit, compute metrics, build tables.

        Returns:
            A dict with keys "method", "function" ({"a", "b"}), "metrics",
            "Additional_Table_Details" and "Curve_Data_Points". Returns ``None``
            when an error is reported through ``utils.handle_error`` and that
            helper does not raise.
        """
        is_insufficient_data = False
        invalid_data = False
        try:
            self.processed_data = utils.process_input_data(self.data)
            # Convert x/y to numpy arrays and, when weighted, compute weights.
            self.x_values, self.y_values, self.weights = utils.validate_points(
                self.processed_data['x'], self.processed_data['y'], self.weighted
            )

            is_insufficient_data = len(self.x_values) < 2 or len(self.y_values) < 2
            invalid_data = not np.all(np.isfinite(self.x_values)) or not np.all(np.isfinite(self.y_values))

            # Report each condition with its own message. The priority
            # (insufficient before invalid) matches the except branch below.
            if is_insufficient_data:
                utils.handle_error(ErrorMessages.ERROR_INSUFFICIENT_DATA_POINTS)
            elif invalid_data:
                utils.handle_error(ErrorMessages.ERROR_INVALID_DATA)
            else:
                self.a, self.b = self.get_a_b()
                self.coefficients = (self.a, self.b)
                metrics = self.get_metrics()
                additional_table_details = self.get_table_data()
                curve_data_points = utils.generate_interpolated_curve_data(
                    x_data=self.x_values,
                    y_data=self.y_values,
                    curve_type=LINEARMODEL,
                    coefficients=self.coefficients,
                )

                return {
                    "method": LINEARMODEL,
                    "function": {"a": self.a, "b": self.b},
                    "metrics": metrics,
                    "Additional_Table_Details": additional_table_details,
                    "Curve_Data_Points": curve_data_points,
                }
        except Exception as e:
            if is_insufficient_data:
                utils.handle_error(ErrorMessages.ERROR_INSUFFICIENT_DATA_POINTS)
            elif invalid_data:
                # e.g. NaN or Inf in x or y
                utils.handle_error(ErrorMessages.ERROR_INVALID_DATA)
            else:
                utils.handle_error(str(e))

    # ---------- Regression analysis ---------- #

    def get_a_b(self):
        """Return ``(a, b)`` using the fitting mode selected by the flags."""
        try:
            if self.force_to_zero:
                return self.get_a_b_force_to_zero()
            if self.weighted:
                return self.get_a_b_weighted()
            return self.get_a_b_unweighted()
        except Exception:
            utils.handle_error(ErrorMessages.ERROR_GET_COEFFICIENTS)

    def get_a_b_force_to_zero(self):
        """Least-squares slope for a line through the origin.

        ``b = sum(w*x*y) / sum(w*x^2)``. The weights are included only when
        ``self.weighted`` is true. ``a`` is always 0, and ``b`` is 0 when the
        denominator is 0.
        """
        try:
            if self.weighted:
                numerator = np.sum(self.weights * self.x_values * self.y_values)
                denominator = np.sum(self.weights * self.x_values ** 2)
            else:
                numerator = np.sum(self.x_values * self.y_values)
                denominator = np.sum(self.x_values ** 2)

            b = numerator / denominator if denominator != 0 else 0
            a = 0
            return a, b
        except Exception:
            utils.handle_error(ErrorMessages.ERROR_GET_COEFFICIENTS)

    def get_a_b_weighted(self):
        """Weighted least-squares intercept and slope (``b`` is 0 if x has no weighted spread)."""
        try:
            x_mean = np.average(self.x_values, weights=self.weights)
            y_mean = np.average(self.y_values, weights=self.weights)

            b_num = np.sum(self.weights * (self.x_values - x_mean) * (self.y_values - y_mean))
            b_den = np.sum(self.weights * (self.x_values - x_mean) ** 2)

            b = b_num / b_den if b_den != 0 else 0
            a = y_mean - (b * x_mean)
            return a, b
        except Exception:
            utils.handle_error(ErrorMessages.ERROR_GET_COEFFICIENTS)

    def get_a_b_unweighted(self):
        """Ordinary least-squares intercept and slope (``b`` is 0 if all x are equal)."""
        try:
            x_mean, y_mean = np.mean(self.x_values), np.mean(self.y_values)
            denominator = np.sum((self.x_values - x_mean) ** 2)
            numerator = np.sum((self.x_values - x_mean) * (self.y_values - y_mean))

            b = numerator / denominator if denominator != 0 else 0
            a = y_mean - (b * x_mean)
            return a, b
        except Exception:
            utils.handle_error(ErrorMessages.ERROR_GET_COEFFICIENTS)

    # ---------- Metrics ---------- #

    def get_metrics(self):
        """Goodness-of-fit metrics for the fitted line.

        Replicates at the same x are first averaged. RSS and R² then compare
        those per-x means with the fitted line.

        Returns:
            ``{"RSS", "R_Squared", "Adjusted_R_Squared", "LLD"}``, or ``None``
            if the coefficients have not been computed.
        """
        try:
            if self.a is None or self.b is None:
                return None

            # Residual sum of squares on the mean response at each distinct x.
            x_unique = np.unique(self.x_values)
            y_averages = np.array([np.mean(self.y_values[self.x_values == e]) for e in x_unique])
            y_predicted = self.predict_unweighted(x_unique)
            rss = np.sum((y_averages - y_predicted) ** 2)

            # Total sum of squares of the per-x means, used for the coefficient
            # of determination.
            ss = np.sum((y_averages - np.mean(y_averages)) ** 2)
            r_squared = utils.calculate_r_squared(rss, ss)

            # Presumably the third argument is the number of predictors (one
            # for a straight line). TODO: confirm against utils.
            n = len(y_averages)
            adjusted_r_squared = utils.calculate_adjusted_r_squared(r_squared, n, 1)

            # Lower limit of detection derived from the blank responses.
            lld = utils.calculate_lld(
                blank_y=self.processed_data['blank_y'],
                model_type=LINEARMODEL,
                params=self.coefficients,
                x_values=self.x_values,
            )

            return {
                "RSS": rss,
                "R_Squared": r_squared,
                "Adjusted_R_Squared": adjusted_r_squared,
                "LLD": lld,
            }
        except Exception:
            utils.handle_error(ErrorMessages.ERROR_GET_METRICS)

    def predict_unweighted(self, values):
        """Evaluate the fitted line ``a + b * values``.

        A missing slope (``None``) falls back to 1, and a missing intercept
        falls back to 0. A fitted slope of 0 is used as-is, so a horizontal
        fit predicts ``a`` everywhere.
        """
        slope = 1 if self.b is None else self.b
        y_intercept = 0 if self.a is None else self.a
        return y_intercept + (slope * values)

    # ---------- Table Data ---------- #

    def get_table_data(self):
        """Build one result row per sample group, in the order groups first appear."""
        try:
            sample_groups = self.generate_additional_table_by_sample_groups()

            # executor.map yields results in submission order, so the rows are
            # deterministic and follow sample_groups. as_completed would instead
            # return them in thread-completion order.
            with concurrent.futures.ThreadPoolExecutor() as executor:
                return list(executor.map(self.group_and_calculate_metrics, sample_groups))
        except Exception:
            utils.handle_error(ErrorMessages.ERROR_GENERATE_TABLE)

    def group_and_calculate_metrics(self, group):
        """Back-calculate x from y for one sample group.

        Args:
            group: A dict as produced by
                ``generate_additional_table_by_sample_groups``, with the keys
                'identifier', 'coordinates', 'x_values' and 'y_values'.

        Returns:
            The group's inputs plus:

            * "x_predicted": ``(y - a) / b`` for each y;
            * "x_difference": the percent deviation of x_predicted from the
              given x, or ``None`` where x is ``None`` or 0;
            * "cv": ``utils.compute_statistical_cv(x_predicted)``;
            * "curve_y": for standards (identifier starts with "S",
              case-insensitively), the curve value at each x; otherwise ``[]``.
        """
        try:
            identifier = group['identifier']
            coordinates = group['coordinates']
            x_values = group['x_values']
            y_values = group['y_values']

            # Inverting a zero or missing slope is undefined, so fall back to 1
            # to keep the division below from failing.
            slope = self.b if self.b is not None and self.b != 0 else 1
            y_intercept = self.a if self.a is not None else 0

            if len(x_values) != len(y_values):
                utils.handle_error(ErrorMessages.ERROR_INVALID_REQUEST)

            x_predicted = [(y - y_intercept) / slope for y in y_values]

            x_difference = [
                ((x_pred - x) / x) * 100 if x is not None and x != 0 else None
                for x_pred, x in zip(x_predicted, x_values)
            ]

            # Coefficient of variation of the back-calculated values.
            cv = utils.compute_statistical_cv(x_predicted)

            # One curve value per x, for standards only.
            curve_y = []
            if identifier.upper().startswith("S"):
                curve_y = [utils.calculate_curve_y(x, LINEARMODEL, self.coefficients) for x in x_values]

            return {
                "identifier": identifier,
                "coordinates": coordinates,
                "x_values": x_values,
                "x_predicted": x_predicted,
                "x_difference": x_difference,
                "y_values": y_values,
                "cv": cv,
                "curve_y": curve_y,
            }
        except Exception:
            utils.handle_error(ErrorMessages.ERROR_CALCULATE_METRICS)

    def generate_additional_table_by_sample_groups(self):
        """Merge raw entries that share an identifier.

        The groups are returned in the order each identifier first appears.
        An entry with a falsy 'x' contributes zeros, one per y value.
        Coordinates are upper-cased.
        """
        try:
            grouped_data = defaultdict(lambda: {'x_values': [], 'y_values': [], 'coordinates': [], 'identifier': None})
            for entry in self.data:
                identifier = entry['identifier']
                coordinates = entry['coordinates']
                x = entry['x'] if entry['x'] else [0] * len(entry['y'])
                y = entry['y']

                group = grouped_data[identifier]
                group['x_values'].extend(x)
                group['y_values'].extend(y)
                group['coordinates'].append(coordinates.upper())
                group['identifier'] = identifier

            return list(grouped_data.values())
        except Exception:
            utils.handle_error(ErrorMessages.ERROR_GENERATE_TABLE)