from collections import defaultdict

import numpy as np
from scipy.optimize import curve_fit
from tabulate import tabulate

class FourPL_Weighted:
    """Fit a four-parameter logistic (4PL) standard curve to plate data.

    The model is ``y = d + (a - d) / (1 + (x / c) ** b)``: ``a`` is the
    response as ``x -> 0``, ``d`` the response as ``x -> inf``, ``c`` the
    concentration at which the response is halfway between ``a`` and ``d``,
    and ``b`` the slope factor.

    Wells are classified by their label in ``plate_format``:

    * labels starting with ``"S"`` are standards; replicates sharing a label
      are averaged and the averages are fitted;
    * labels starting with ``"B"`` are blanks, used for the lower limit of
      detection (LLD);
    * every well (standards and blanks included) is back-calculated through
      the fitted curve when the results table is generated.

    Parameters
    ----------
    plate_data_x, plate_data_y : dict
        Map each plate key to a list of per-well concentrations / signals,
        aligned index-by-index with ``plate_format[plate]``. Empty cells may
        be given as ``""`` or ``None``.
    plate_format : dict
        Map each plate key to a list of well labels.
    weighted : bool
        When True the fit uses ``1 / y**2`` weights (``sigma = |y|``);
        otherwise all standard points are weighted equally.
    """

    # Number of free parameters in the 4PL model (a, b, c, d).
    _N_PARAMS = 4

    def __init__(self, plate_data_x, plate_data_y, plate_format, weighted=False):
        self.plate_data_x = plate_data_x
        self.plate_data_y = plate_data_y
        self.plate_format = plate_format
        self.weighted = weighted  # Whether to apply 1/y^2 weights in the fit
        # Averaged standard-curve points (set by calculate_averages_and_generate_arrays).
        self.x_arr = None
        self.y_arr = None
        # Fitted 4PL parameters; None until a fit succeeds.
        self.a = None
        self.b = None
        self.c = None
        self.d = None
        # Fit results.
        self.y_pred = None
        self.regression_metrics = None
        self.lld = None
        # Per-point weights and the equivalent sigma passed to curve_fit (None = unweighted).
        self.weights = None
        self.std_devs = None

    @staticmethod
    def four_param_logistic(x, a, b, c, d):
        """Evaluate the 4PL model at concentration(s) ``x``.

        ``x`` is floored at a tiny positive value and ``c`` is nudged by the
        same amount, so ``x = 0`` and ``c = 0`` do not divide by zero.
        """
        epsilon = 1e-10  # Small constant to avoid division by zero
        x = np.maximum(x, epsilon)  # Ensure all x values are positive
        return d + ((a - d) / (1 + (x / (c + epsilon)) ** b))

    @staticmethod
    def inverse_four_param_logistic(y, a, b, c, d):
        """Back-calculate the concentration that produces response ``y``.

        Inverts :meth:`four_param_logistic`. ``y`` is first floored at a tiny
        positive value. When ``y`` lies outside the open interval between
        ``a`` and ``d`` the base of the power is negative and the result is
        generally ``nan``; ``y == d`` divides by zero.
        """
        epsilon = 1e-10  # Floor so non-positive signals are not inverted as-is
        y = np.maximum(y, epsilon)
        return c * (((a - y) / (y - d)) ** (1 / b))

    @staticmethod
    def _is_missing(value):
        """Return True for an empty cell (``None`` or ``""``)."""
        return value is None or (isinstance(value, str) and value == "")

    def _fitted_params(self):
        """Return ``(a, b, c, d)`` if a fit has succeeded, otherwise None."""
        params = (self.a, self.b, self.c, self.d)
        return None if any(p is None for p in params) else params

    def calculate_averages_and_generate_arrays(self):
        """Average replicate standards and prepare the arrays used for fitting.

        Sets ``x_arr``/``y_arr`` (one averaged point per distinct ``"S..."``
        label, in first-seen order), ``weights`` (``1 / y**2`` when
        ``weighted`` else ones) and ``std_devs`` (the matching ``sigma`` for
        ``curve_fit``: ``|y|`` when weighted, ``None`` otherwise). Wells whose
        concentration or signal is empty are skipped.
        """
        x_values = defaultdict(list)
        y_values = defaultdict(list)

        for plate, wells in self.plate_format.items():
            plate_x = self.plate_data_x[plate]
            plate_y = self.plate_data_y[plate]
            for i, position in enumerate(wells):
                if not position.startswith('S'):
                    continue
                x, y = plate_x[i], plate_y[i]
                # Skip the pair together so x and y averages stay aligned.
                if self._is_missing(x) or self._is_missing(y):
                    continue
                x_values[position].append(float(x))
                y_values[position].append(float(y))

        labels = list(x_values)
        self.x_arr = np.array([np.mean(x_values[label]) for label in labels], dtype=float)
        self.y_arr = np.array([np.mean(y_values[label]) for label in labels], dtype=float)

        if self.weighted:
            # 1/y^2 weights are equivalent to a per-point sigma proportional to |y|.
            self.weights = 1.0 / self.y_arr ** 2
            self.std_devs = np.abs(self.y_arr)
        else:
            self.weights = np.ones_like(self.y_arr)
            self.std_devs = None

    def fit_curve_and_calculate_metrics(self):
        """Fit the 4PL model to the averaged standards and compute fit statistics.

        Populates ``a``, ``b``, ``c``, ``d``, ``y_pred``,
        ``regression_metrics`` and ``lld``. If the averaged arrays have not
        been built yet, they are built first. On failure (too few standard
        points, or ``curve_fit`` raising ``RuntimeError``/``ValueError``) a
        message is printed and all of these attributes are left as ``None``.
        """
        if self.x_arr is None or self.y_arr is None:
            self.calculate_averages_and_generate_arrays()

        # Clear previous results so a failed fit cannot leave stale values behind.
        self.a = self.b = self.c = self.d = None
        self.y_pred = self.regression_metrics = self.lld = None

        x, y = self.x_arr, self.y_arr
        if x.size < self._N_PARAMS:
            print(f"Curve fitting failed. At least {self._N_PARAMS} averaged standard "
                  f"points are required, got {x.size}.")
            return

        # Data-driven starting point: a ~ response at the lowest concentration,
        # d ~ response at the highest, c ~ median positive concentration, b = 1.
        order = np.argsort(x)
        positive_x = x[x > 0]
        p0 = [
            float(y[order[0]]),
            1.0,
            float(np.median(positive_x)) if positive_x.size else 1.0,
            float(y[order[-1]]),
        ]
        # b >= 0 loses no generality (a negative slope is the same curve with a
        # and d swapped); c >= 0 keeps (x / c) ** b real-valued.
        lower = [-np.inf, 0.0, 0.0, -np.inf]
        upper = [np.inf, np.inf, np.inf, np.inf]

        try:
            params, _covariance = curve_fit(
                self.four_param_logistic, x, y,
                p0=p0, sigma=self.std_devs, bounds=(lower, upper), max_nfev=10000,
            )
        except (RuntimeError, ValueError) as exc:
            print(f"Curve fitting failed. Check your initial parameters and bounds. ({exc})")
            return

        self.a, self.b, self.c, self.d = (float(p) for p in params)

        # Predicted responses at the averaged standard concentrations.
        self.y_pred = self.four_param_logistic(x, self.a, self.b, self.c, self.d)
        self.regression_metrics = self.calculate_regression_metrics(y, self.y_pred)
        self.lld = self.calculate_lld()

    @staticmethod
    def calculate_regression_metrics(y_actual, y_pred, *, num_predictors=1):
        """Return goodness-of-fit statistics for predicted vs. observed values.

        Parameters
        ----------
        y_actual, y_pred : array-like
            Observed and predicted responses of equal length.
        num_predictors : int, keyword-only
            Predictor count in the adjusted R-squared penalty
            ``(n - 1) / (n - num_predictors - 1)``. Defaults to 1.

        Returns
        -------
        dict
            Keys ``'R_squared'``, ``'Adjusted_R_squared'`` and ``'RSS'``.
            ``R_squared`` is ``nan`` when all observed values are equal (zero
            total variance). ``Adjusted_R_squared`` falls back to
            ``R_squared`` when there are too few points for the penalty.
        """
        y_actual = np.asarray(y_actual, dtype=float)
        y_pred = np.asarray(y_pred, dtype=float)

        rss = np.sum((y_actual - y_pred) ** 2)  # residual sum of squares
        tss = np.sum((y_actual - np.mean(y_actual)) ** 2)  # total sum of squares
        r_squared = 1 - (rss / tss) if tss > 0 else np.nan

        n = y_actual.size
        dof = n - num_predictors - 1
        # Guard the denominator before dividing (n == 2 used to raise ZeroDivisionError).
        if dof > 0:
            adjusted_r_squared = 1 - ((1 - r_squared) * (n - 1) / dof)
        else:
            adjusted_r_squared = r_squared

        return {
            'R_squared': r_squared,
            'Adjusted_R_squared': adjusted_r_squared,
            'RSS': rss,
        }

    def calculate_lld(self):
        """Return the lower limit of detection as a concentration, or None.

        The signal threshold is ``mean_blank ± 2 * 1.645 * sd_blank`` (1.645
        is the one-sided 95 % normal quantile). The sign follows the curve
        direction over the standard range: ``+`` if the response increases
        with concentration, ``-`` otherwise. The threshold is then
        back-calculated through the fitted curve. Returns None when the curve
        has not been fitted, or when there are no blank or standard wells.
        """
        params = self._fitted_params()
        if params is None:
            return None

        blank_values = [
            float(self.plate_data_y[plate][i])
            for plate, wells in self.plate_format.items()
            for i, well in enumerate(wells)
            if well.startswith("B") and not self._is_missing(self.plate_data_y[plate][i])
        ]
        if not blank_values:
            return None

        mean_blank = np.mean(blank_values)
        std_blank = np.std(blank_values)  # population SD (ddof=0)

        # Concentration of each standard well itself (not every well on its plate).
        standard_concentrations = [
            float(self.plate_data_x[plate][i])
            for plate, wells in self.plate_format.items()
            for i, well in enumerate(wells)
            if well.startswith("S") and not self._is_missing(self.plate_data_x[plate][i])
        ]
        if not standard_concentrations:
            return None

        curvemin = self.four_param_logistic(min(standard_concentrations), *params)
        curvemax = self.four_param_logistic(max(standard_concentrations), *params)

        margin = 2 * 1.645 * std_blank
        threshold = mean_blank + margin if curvemax > curvemin else mean_blank - margin
        return self.inverse_four_param_logistic(threshold, *params)

    def calculate_cv(self, sample_concentrations):
        """Return the %CV (sample SD / mean * 100) of the given concentrations.

        ``None`` entries (wells that could not be back-calculated) are
        ignored. Returns None if fewer than two values remain, and 0 if the
        mean is zero.
        """
        values = [v for v in sample_concentrations if v is not None]
        if len(values) < 2:
            return None
        mean = np.mean(values)
        std_dev = np.std(values, ddof=1)
        return (std_dev / mean) * 100 if mean != 0 else 0

    def coefficients(self):
        """Fit the curve and print the estimated 4PL parameters."""
        self.calculate_averages_and_generate_arrays()
        self.fit_curve_and_calculate_metrics()

        print("Estimated parameters:")
        print("a:", self.a)
        print("b:", self.b)
        print("c:", self.c)
        print("d:", self.d)

    def display_regression_metrics(self):
        """Fit the curve and print its regression metrics and LLD."""
        self.calculate_averages_and_generate_arrays()
        self.fit_curve_and_calculate_metrics()

        print("Regression Metrics:")
        if self.regression_metrics is None:
            print("No regression metrics available (curve fit did not succeed).")
        else:
            for key, value in self.regression_metrics.items():
                print(f"{key}: {value}")
        print(f"LLD: {self.lld}")

    def _predict_concentration(self, signal):
        """Back-calculate one well's concentration, or None if not possible."""
        params = self._fitted_params()
        if params is None or self._is_missing(signal):
            return None
        return float(self.inverse_four_param_logistic(float(signal), *params))

    def _build_additional_table(self):
        """Return per-well result rows using the current fitted parameters.

        Each row is ``[sample_type, well_label, signal, predicted, cv,
        difference_pct]``:

        * ``sample_type`` is the first two characters of the well label.
        * ``well_label`` is the plate key plus the 1-based well position.
        * ``cv`` is the %CV of back-calculated concentrations over all wells
          sharing the sample type.
        * ``difference_pct`` is None when the nominal concentration is
          empty/zero or no prediction is available.
        """
        predictions = {}
        cv_groups = defaultdict(list)
        for plate, wells in self.plate_format.items():
            for i, well in enumerate(wells):
                predicted = self._predict_concentration(self.plate_data_y[plate][i])
                predictions[(plate, i)] = predicted
                cv_groups[well[:2]].append(predicted)

        cv_values = {sample_type: self.calculate_cv(concs) for sample_type, concs in cv_groups.items()}

        rows = []
        for plate, wells in self.plate_format.items():
            for i, well in enumerate(wells):
                sample_type = well[:2]
                predicted = predictions[(plate, i)]
                raw_actual = self.plate_data_x[plate][i]
                concentration_actual = float(raw_actual) if raw_actual else None
                if concentration_actual and predicted is not None:
                    concentration_difference = (predicted - concentration_actual) / concentration_actual * 100
                else:
                    concentration_difference = None
                rows.append([
                    sample_type,
                    f"{plate}{i + 1}",
                    self.plate_data_y[plate][i],
                    predicted,
                    cv_values.get(sample_type),
                    concentration_difference,
                ])
        return rows

    def generate_and_print_additional_table(self):
        """Fit the curve, then print the per-well results table as a grid."""
        self.calculate_averages_and_generate_arrays()
        self.fit_curve_and_calculate_metrics()

        additional_table = self._build_additional_table()

        print("TABLE DATA")
        headers = ["Sample Type", "Well", "FI - EndPoint", "Predicted X", "%CV", "Concentration Difference (%)"]
        print(tabulate(additional_table, headers=headers, tablefmt="grid"))

# Example data ---------------------------------------------------------------
# Rows A-C each carry the same five-point standard series followed by a blank;
# row D holds unknowns (U1, U2), one more well labelled X, and a blank.
plate_data_x = {row: [100, 50, 25, 12.5, 6.25, 0] for row in "ABC"}
plate_data_x["D"] = [0] * 6

plate_data_y = {
    "A": [640, 426, 256, 130, 70, 8],
    "B": [664, 400, 240, 134, 76, 10],
    "C": [635, 410, 244, 136, 73, 8],
    "D": [400, 421, 301, 290, 9, 9],
}

plate_format = {row: ["S5", "S4", "S3", "S2", "S1", "B"] for row in "ABC"}
plate_format["D"] = ["U1", "U1", "U2", "U2", "X", "B"]

# Instantiate and run the analysis only when executed as a script, so that
# importing this module does not trigger curve fitting and printing.
if __name__ == "__main__":
    analyzer = FourPL_Weighted(plate_data_x, plate_data_y, plate_format, weighted=True)
    analyzer.coefficients()
    analyzer.display_regression_metrics()
    analyzer.generate_and_print_additional_table()
