import numpy as np
from calculation_methods.py_calculations.calculation_constants.constants import FOURPLMODEL
from calculation_methods.py_calculations.error_handler.error_messages import ErrorMessages
from calculation_methods.py_calculations import calculation_utils as utils

class FourParameterLogisticRegression:
    """Fit a four-parameter logistic (4PL) curve and derive results from it.

    The constructor processes the raw input, validates the calibration
    points, fits the model through ``utils.calculate_params_with_lmfit`` and
    stores the four coefficients as a dict keyed ``'a'``, ``'b'``, ``'c'``,
    ``'d'``.

    The inverse implemented in ``calculate_concentration`` corresponds to the
    forward model ``y = d + (a - d) / (1 + (x / c) ** b)``.
    """

    # Order in which the fitted parameters are unpacked / reported.
    _COEFFICIENT_KEYS = ('a', 'b', 'c', 'd')

    def __init__(self, input_data, standard_values, blank_values, weighted=False, forced_to_zero=False):
        """Process the input data and fit the 4PL model.

        Args:
            input_data: Raw input; passed to ``utils.process_input_data`` and
                later iterated as entries with ``'identifier'`` and ``'y'``
                keys (see ``calculate_cv``).
            standard_values: Stored on the instance; not used in this class.
            blank_values: Stored on the instance; not used in this class.
            weighted: Forwarded to point validation and to the fit.
            forced_to_zero: Forwarded to the fit; also makes
                ``calculate_concentration`` use ``a = 0``.
        """
        self.input_data = input_data
        self.standard_values = standard_values
        self.blank_values = blank_values
        self.weighted = weighted
        self.forced_to_zero = forced_to_zero

        # Process input data
        self.processed_data = utils.process_input_data(input_data)

        # Validate and prepare data for fitting
        self.x_values, self.y_values, self.weights = utils.validate_points(
            self.processed_data['x'], self.processed_data['y'], weighted
        )

        # Ensure consistent length: truncate everything to the shorter series
        min_length = min(len(self.x_values), len(self.y_values))
        self.x_values = self.x_values[:min_length]
        self.y_values = self.y_values[:min_length]
        if self.weights is not None:
            self.weights = self.weights[:min_length]

        # Four free parameters need at least four points.
        # NOTE(review): presumably handle_error raises; if it only logs,
        # the fit below still runs on too few points - confirm.
        if len(self.x_values) < 4:
            utils.handle_error(ErrorMessages.ERROR_INSUFFICIENT_DATA_POINTS)

        # Fit the 4PL model using centralized utility
        self.params = utils.calculate_params_with_lmfit(
            x_values=self.x_values,
            y_values=self.y_values,
            model_type=FOURPLMODEL,
            weighted=self.weighted,
            forced_to_zero=self.forced_to_zero,
            weights=self.weights
        )
        self.coefficients = self.extract_coefficients()

    def extract_coefficients(self):
        """Return the fitted parameters as ``{'a', 'b', 'c', 'd'}`` floats.

        Any failure (``self.params`` not a list, wrong number of values,
        non-numeric values) is reported through ``utils.handle_error``.
        """
        try:
            if not isinstance(self.params, list):
                # NOTE(review): if handle_error raises, the broad except
                # below re-reports this as ERROR_GET_COEFFICIENTS - confirm
                # that is the intended message.
                utils.handle_error(ErrorMessages.ERROR_INSUFFICIENT_DATA_POINTS)
            a, b, c, d = self.params
            return {'a': float(a), 'b': float(b), 'c': float(c), 'd': float(d)}
        except Exception as e:
            utils.handle_error(ErrorMessages.ERROR_GET_COEFFICIENTS.format(str(e)))

    def predict(self, x_values):
        """Predict y-values for the given x-values.

        Returns an empty array when ``x_values`` is ``None`` or empty;
        otherwise one ``utils.calculate_curve_y`` result per x, as an array.
        """
        if x_values is None or len(x_values) == 0:
            return np.array([])
        x_array = np.array(x_values, dtype=float)
        return np.array([utils.calculate_curve_y(x, FOURPLMODEL, self.coefficients) for x in x_array])

    def calculate_concentration(self, y):
        """Calculate the concentration (x) that corresponds to a y-value.

        Computes ``c * ((a - d) / (y - d) - 1) ** (1 / b)``.

        Returns:
            The concentration, or ``None`` when it cannot be computed: ``y``
            is missing, non-numeric or non-finite, ``y`` equals ``d``, the
            base of the power is not positive, the result is negative or
            non-finite, or the arithmetic overflows / divides by zero.
        """
        try:
            if self.forced_to_zero:
                # Force a = 0 and take b, c, d straight from the fit result.
                # NOTE(review): predict() uses self.coefficients['a'] as
                # fitted; confirm the fit already pins a to 0 in this mode.
                a, b, c, d = (0, *self.params[1:])
            else:
                a, b, c, d = (self.coefficients[k] for k in self._COEFFICIENT_KEYS)

            # np.isfinite raises TypeError for None / strings; treat such a
            # reading like any other unusable value instead of crashing.
            # Only this call is guarded so unrelated TypeErrors still surface.
            try:
                y_is_finite = np.isfinite(y)
            except TypeError:
                return None

            if not y_is_finite or (y - d == 0):
                return None
            base = ((a - d) / (y - d)) - 1
            if base <= 0:
                # y lies outside the range spanned by the asymptotes.
                return None
            concentration = c * (base ** (1 / b))
            return concentration if np.isfinite(concentration) and concentration >= 0 else None
        except (OverflowError, ZeroDivisionError, ValueError):
            return None

    def calculate_cv(self, sample_type):
        """Calculate the coefficient of variation for a sample type.

        Collects the concentration of every y reading of every input entry
        whose ``'identifier'`` starts with ``sample_type``, drops readings
        that yield no concentration, and passes the rest to
        ``utils.compute_statistical_cv``.
        """
        concentrations = [
            self.calculate_concentration(y) for entry in self.input_data
            if entry['identifier'].startswith(sample_type) for y in entry['y']
        ]
        valid_concentrations = [c for c in concentrations if c is not None]
        return utils.compute_statistical_cv(valid_concentrations)

    def get_metrics(self):
        """Calculate and return all metrics for the fitted curve.

        Returns:
            A dict with the keys ``"method"``, ``"function"`` (the
            coefficient dict), ``"metrics"`` (fit statistics plus ``'LLD'``),
            ``"Additional_Table_Details"`` and ``"Curve_Data_Points"``.
        """
        y_pred = self.predict(self.x_values)
        coeff_list = [self.coefficients[key] for key in self._COEFFICIENT_KEYS]

        metrics = utils.calculate_statistics(
            y_true=self.y_values, y_pred=y_pred, coefficients=coeff_list
        )
        metrics['LLD'] = utils.calculate_lld(
            blank_y=self.processed_data['blank_y'],
            model_type=FOURPLMODEL,
            params=self.coefficients,
            x_values=self.x_values,
        )
        additional_table_details = utils.generate_table_by_sample_groups(
            input_data=self.input_data,
            model_type=FOURPLMODEL,
            params=self.coefficients,
            predict_func=self.calculate_concentration,
            cv_func=self.calculate_cv,
        )
        curve_data_points = utils.generate_interpolated_curve_data(
            x_data=self.x_values,
            y_data=self.y_values,
            curve_type=FOURPLMODEL,
            coefficients=self.coefficients,
        )

        return {
            "method": FOURPLMODEL,
            "function": self.coefficients,
            "metrics": metrics,
            "Additional_Table_Details": additional_table_details,
            "Curve_Data_Points": curve_data_points
        }