import numpy as np
from collections import defaultdict
import concurrent.futures
import matplotlib.pyplot as plt
from lmfit import Model, Parameters
import pandas as pd

class FourParameterLogisticRegression:
    """Fit a four-parameter logistic (4PL) calibration curve and report metrics.

    The model is ``y = d + (a - d) / (1 + (x / c) ** b)``; for ``b > 0`` its
    value at ``x = 0`` is ``a``.

    ``data`` is a list of well dicts with keys ``identifier``, ``coordinates``,
    ``y`` (list of responses) and optionally ``x`` (list of known
    concentrations). Identifiers starting with ``S`` (case-insensitive) are
    averaged per identifier and used as the fitting standards; identifiers
    starting with ``B`` are used as blanks for the LLD. Every entry appears in
    the per-well results table.

    Args:
        data: Well records as described above.
        weighted: If True, pass response-dependent weights to the fit.
        forced_zero: If True, pin parameter ``a`` to 0 during the fit.
    """

    # Multiplier applied (twice) to the blank standard deviation in the LLD;
    # presumably the one-sided 95% normal quantile — confirm.
    lld_constant = 1.645

    def __init__(self, data, weighted=False, forced_zero=False):
        self.data = data
        self.weighted = weighted
        self.forced_zero = forced_zero
        self.process_and_get_all_values()
        self.valid_points()
        # params is [a, b, c, d] expressed in the original x units.
        self.params, self.r_squared = self.calculate_params_with_lmfit()

    def process_and_get_all_values(self):
        """Group wells into standards and blanks and cache flattened views.

        Sets:
            averages_standards: ``{identifier: {'avg_x', 'avg_y'}}`` per standard
                (``avg_x`` is None when the standard has no x values).
            avg_y_blanks: mean blank response, or None without blanks.
            blank_values: every raw blank response (used for the blank SD).
            blank_y: ``[avg_y_blanks]`` or ``[]`` (kept for compatibility).
            x, y: paired per-standard means, only for standards with an x,
                so ``x[i]`` and ``y[i]`` always describe the same standard.
            all_identifiers, all_wells, all_x_values, all_y_values: one item
                per individual y reading, in input order.
        """
        sum_x_standards = defaultdict(list)
        sum_y_standards = defaultdict(list)
        sum_y_blanks = []

        all_y_values, all_identifiers, all_wells, all_x_values = [], [], [], []

        for entry in self.data:
            identifier = entry['identifier'].upper()
            x = entry.get('x') or []
            y = entry['y']
            coordinates = entry['coordinates']

            all_y_values.extend(y)
            all_identifiers.extend([identifier] * len(y))
            all_wells.extend(coordinates if isinstance(coordinates, list) else [coordinates.upper()] * len(y))
            all_x_values.extend(x if x else [0] * len(y))

            if identifier.startswith('S'):
                sum_x_standards[identifier].extend(x)
                sum_y_standards[identifier].extend(y)
            elif identifier.startswith('B'):
                sum_y_blanks.extend(y)

        # Both dicts receive the same keys in the same loop iteration, so
        # indexing sum_x_standards by the y-dict's keys is always valid.
        self.averages_standards = {
            identifier: {
                'avg_x': np.mean(sum_x_standards[identifier]) if sum_x_standards[identifier] else None,
                'avg_y': np.mean(y_list),
            }
            for identifier, y_list in sum_y_standards.items()
        }
        self.avg_y_blanks = np.mean(sum_y_blanks) if sum_y_blanks else None
        self.blank_values = list(sum_y_blanks)

        # Build x and y together so a standard without a known x cannot shift
        # the y list relative to the x list.
        paired = [
            (v['avg_x'], v['avg_y'])
            for v in self.averages_standards.values()
            if v['avg_x'] is not None
        ]
        self.x = [px for px, _ in paired]
        self.y = [py for _, py in paired]
        self.blank_y = [self.avg_y_blanks] if self.avg_y_blanks is not None else []

        self.all_identifiers = all_identifiers
        self.all_wells = all_wells
        self.all_x_values = all_x_values
        self.all_y_values = all_y_values

    def valid_points(self):
        """Convert the paired standard means into aligned numpy arrays.

        A pair is dropped (both halves together) when its x or y is falsy,
        e.g. 0 or None, so ``x_values`` and ``y_values`` stay index-aligned.

        Raises:
            ValueError: If an x value cannot be converted to float.
        """
        try:
            pairs = [(float(x), y) for x, y in zip(self.x, self.y) if x and y]
        except (TypeError, ValueError) as e:
            raise ValueError(f"Error converting points to float: {e}") from e
        self.x_values = np.array([px for px, _ in pairs])
        self.y_values = np.array([py for _, py in pairs])

    def four_param_logistic(self, x, a, b, c, d):
        """Evaluate ``d + (a - d) / (1 + (x / c) ** b)`` (scalar or array x)."""
        return d + ((a - d) / (1 + (x / c) ** b))

    def calculate_params_with_lmfit(self):
        """Fit the 4PL model to the standards with lmfit.

        Returns:
            ``([a, b, c, d], r_squared)`` with ``c`` in the original x units.

        Raises:
            ValueError: If there are no valid standard points to fit.
        """
        x_values = np.asarray(self.x_values, dtype=float)
        y_values = np.asarray(self.y_values, dtype=float)
        if x_values.size == 0:
            raise ValueError("No valid standard points available for fitting.")

        # Rescale x multiplicatively for numerical conditioning (x can be as
        # small as ~1e-11). The model depends on x only through x / c, so
        # fitting on x / s and multiplying the fitted c by s is exact. A
        # min-max (shifted) scale cannot be undone this way.
        scale_factor = float(np.max(np.abs(x_values))) or 1.0
        x_scaled = x_values / scale_factor

        model = Model(self.four_param_logistic)

        # Initial guesses and bounds (all parameters non-negative).
        params = Parameters()
        params.add('a', value=np.max(y_values), min=0)
        params.add('b', value=1, min=0)
        params.add('c', value=np.median(x_scaled), min=0)
        params.add('d', value=np.min(y_values), min=0)

        if self.forced_zero:
            params['a'].set(value=0, vary=False)

        weights = None
        if self.weighted:
            # Small epsilon avoids division by zero.
            # NOTE(review): lmfit multiplies residuals by `weights`, so this
            # minimises sum(((y - f) / y**2) ** 2); weights of 1/y would give
            # the more common 1/y**2 weighting of squared residuals — confirm.
            weights = 1 / (y_values ** 2 + 1e-10)

        result = model.fit(y_values, params, x=x_scaled, weights=weights)

        best_params = result.params.valuesdict()
        best_params['c'] = best_params['c'] * scale_factor
        param_list = list(best_params.values())

        r_squared = self.calculate_r_squared(param_list, x_scaled)

        return param_list, r_squared

    def calculate_r_squared(self, params, x_scaled=None):
        """Return R² of ``params`` against the per-concentration mean responses.

        Args:
            params: ``[a, b, c, d]`` in original x units.
            x_scaled: Ignored; kept so existing callers remain compatible.
                Predictions are made at the unique original-unit x values.
        """
        r_squared, _, _ = self._fit_statistics(params)
        return r_squared

    def _averaged_standards(self):
        """Return sorted unique x values and the mean y at each of them."""
        x_unique = np.unique(self.x_values)
        y_avg = np.array([np.mean(self.y_values[self.x_values == xi]) for xi in x_unique])
        return x_unique, y_avg

    def _fit_statistics(self, params):
        """Return ``(r_squared, rss, n)`` for ``params`` on the averaged standards."""
        x_unique, y_avg = self._averaged_standards()
        y_pred = self.four_param_logistic(x_unique, *params)
        rss = np.sum((y_avg - y_pred) ** 2)
        tss = np.sum((y_avg - np.mean(y_avg)) ** 2)
        return 1 - (rss / tss), rss, len(y_avg)

    def _abcd(self):
        """Return fitted ``(a, b, c, d)``, with ``a`` pinned to 0 in forced-zero mode."""
        a, b, c, d = self.params
        return (0, b, c, d) if self.forced_zero else (a, b, c, d)

    def _predict_concentration(self, y):
        """Invert the fitted curve: return the x that yields response ``y``.

        Solves ``y = d + (a - d) / (1 + (x / c) ** b)`` for ``x``. Returns None
        when there is no fit, when ``y == d``, when the fractional power would
        need a negative or non-finite base, or when the result is negative or
        non-finite.
        """
        if self.params is None:
            return None
        a, b, c, d = self._abcd()
        try:
            if y == d:
                return None
            ratio = (a - d) / (y - d) - 1
            # A negative base raised to 1/b is complex for Python floats and
            # NaN for NumPy floats; neither is a usable concentration.
            if not np.isfinite(ratio) or ratio < 0:
                return None
            concentration = c * ratio ** (1 / b)
        except (OverflowError, ZeroDivisionError, ValueError):
            return None
        if not np.isfinite(concentration) or concentration < 0:
            return None
        return float(concentration)

    def pred(self, x_values):
        """Evaluate the fitted curve at ``x_values`` (scalar or array)."""
        return self.four_param_logistic(x_values, *self._abcd())

    def calculate_metrics(self):
        """Return ``(r_squared, adjusted_r_squared, rss)`` for the fitted curve.

        Statistics compare the curve with the mean response at each unique
        standard concentration. Returns ``(None, None, None)`` without a fit.

        Raises:
            ValueError: Wrapping any error raised during the computation.
        """
        try:
            if self.params is None:
                return None, None, None
            r_squared, rss, n = self._fit_statistics(self._abcd())
            # NOTE(review): (n - 1) / (n - 2) is the single-predictor (linear)
            # adjustment; a 4PL fits up to four parameters — confirm which
            # degrees-of-freedom convention is intended.
            adjusted_r_squared = 1 - ((1 - r_squared) * ((n - 1) / (n - 2))) if n > 2 else r_squared
            return r_squared, adjusted_r_squared, rss
        except Exception as e:
            raise ValueError(f"Error calculating metrics: {e}") from e

    def calculate_lld(self):
        """Return the lower limit of detection as a concentration, or None.

        The signal threshold is the blank mean plus (rising curve) or minus
        (falling curve) ``2 * lld_constant`` blank standard deviations,
        computed from the individual blank readings. The threshold is mapped
        to a concentration by inverting the fitted 4PL curve. Returns None
        without blanks, without a fit, or when the threshold cannot be inverted.

        Raises:
            ValueError: Wrapping any error raised during the computation.
        """
        try:
            blanks = self.blank_values
            if not blanks or self.params is None:
                return None
            mean_blank = np.mean(blanks)
            std_blank = np.std(blanks, ddof=1) if len(blanks) > 1 else 0.0
            curvemin = self.pred(np.min(self.x_values))
            curvemax = self.pred(np.max(self.x_values))
            margin = 2 * self.lld_constant * std_blank
            threshold = mean_blank + margin if curvemax > curvemin else mean_blank - margin
            return self._predict_concentration(threshold)
        except Exception as e:
            raise ValueError(f"Error calculating LLD: {e}") from e

    def calculate_cv(self, sample_type):
        """Return the %CV of back-calculated concentrations for one identifier.

        Only wells whose identifier equals ``sample_type`` (case-insensitive)
        are pooled, so e.g. ``"S1"`` does not pick up ``"S10"``. Readings that
        cannot be inverted are skipped. Returns None without a fit or with
        fewer than two usable concentrations, and 0 when the mean is 0.

        Raises:
            ValueError: Wrapping any error raised during the computation.
        """
        try:
            if self.params is None:
                return None

            target = str(sample_type).upper()
            y_values = [
                y
                for entry in self.data
                if str(entry['identifier']).upper() == target
                for y in entry['y']
            ]
            concentrations = [
                conc
                for conc in map(self._predict_concentration, y_values)
                if conc is not None
            ]

            if len(concentrations) < 2:
                return None

            mean = np.mean(concentrations)
            std_dev = np.std(concentrations, ddof=1)
            return (std_dev / mean) * 100 if mean != 0 else 0
        except Exception as e:
            raise ValueError(f"Error calculating CV: {e}") from e

    def process_well(self, entry):
        """Build the results-table row for one well entry.

        Only the first ``x`` and first ``y`` of the entry are used. The
        concentration difference is the percentage deviation of the
        back-calculated concentration from the known x, when both exist.
        """
        sample_type = entry['identifier']
        x_list = entry.get('x') or []
        x_value = x_list[0] if x_list else None
        y_value = entry['y'][0]

        concentration_predicted = self._predict_concentration(y_value)
        # Falsy x (None or 0) means no usable reference, which also avoids
        # dividing by zero below.
        concentration_actual = float(x_value) if x_value else None
        cv = self.calculate_cv(sample_type)
        concentration_difference = (
            ((concentration_predicted - concentration_actual) / concentration_actual) * 100
            if concentration_predicted is not None and concentration_actual is not None
            else None
        )
        # concentration_actual is computed but intentionally not included.
        return {
            "coordinates": entry['coordinates'],
            "identifier": sample_type,
            "concentration_predicted": concentration_predicted,
            "% cv": cv,
            "concentration_difference": concentration_difference
        }

    def generate_additional_table(self):
        """Return one ``process_well`` row per entry, in ``self.data`` order."""
        # Executor.map yields results in submission order, unlike as_completed.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            return list(executor.map(self.process_well, self.data))

    def get_metrics(self):
        """Return a dict with the method name, parameters, fit stats, LLD and table.

        Raises:
            ValueError: Wrapping any error raised while collecting metrics.
        """
        try:
            r_squared, adjusted_r_squared, rss = self.calculate_metrics()
            table_details = self.generate_additional_table()
            lld = self.calculate_lld()
            method_name = "4-Parameter Logistic Regression"
            if self.weighted:
                method_name += " Weighted"
            if self.forced_zero:
                method_name += " Forced Zero"
            a, b, c, d = self._abcd()
            return {
                "method": method_name,
                "function": {"a": a, "b": b, "c": c, "d": d},
                "RSS": rss,
                "R_Squared": r_squared,
                "Adjusted_R_Squared": adjusted_r_squared,
                "LLD": lld,
                "Additional_Table_Details": table_details
            }
        except Exception as e:
            raise ValueError(f"Error getting metrics: {e}") from e

    def plot_fit(self):
        """Scatter the standard points and overlay the fitted curve, then show it."""
        plt.figure(figsize=(10, 6))
        plt.scatter(self.x_values, self.y_values, label='Data', color='blue')
        x_fit = np.linspace(min(self.x_values), max(self.x_values), 500)
        y_fit = self.pred(x_fit)
        plt.plot(x_fit, y_fit, label='4PL Fit', color='red')
        plt.xlabel('X Values')
        plt.ylabel('Y Values')
        plt.title('4-Parameter Logistic Regression Fit')
        plt.legend()
        plt.show()

# Sample data set 1: five standards (S1-S5) in triplicate with known
# concentrations x, four blank wells (B) used for the LLD, and wells with no
# x (U1, U2, X) that only appear in the per-well results table.
data1 = [
    {"coordinates": ["A1"], "identifier": "S5", "x": [100], "y": [640]},
    {"coordinates": ["A2"], "identifier": "S4", "x": [50], "y": [426]},
    {"coordinates": ["A3"], "identifier": "S3", "x": [25], "y": [256]},
    {"coordinates": ["A4"], "identifier": "S2", "x": [12.5], "y": [130]},
    {"coordinates": ["A5"], "identifier": "S1", "x": [6.25], "y": [70]},
    {"coordinates": ["A6"], "identifier": "B", "x": [], "y": [8]},
    {"coordinates": ["B1"], "identifier": "S5", "x": [100], "y": [664]},
    {"coordinates": ["B2"], "identifier": "S4", "x": [50], "y": [400]},
    {"coordinates": ["B3"], "identifier": "S3", "x": [25], "y": [240]},
    {"coordinates": ["B4"], "identifier": "S2", "x": [12.5], "y": [134]},
    {"coordinates": ["B5"], "identifier": "S1", "x": [6.25], "y": [76]},
    {"coordinates": ["B6"], "identifier": "B", "x": [], "y": [10]},
    {"coordinates": ["C1"], "identifier": "S5", "x": [100], "y": [635]},
    {"coordinates": ["C2"], "identifier": "S4", "x": [50], "y": [410]},
    {"coordinates": ["C3"], "identifier": "S3", "x": [25], "y": [244]},
    {"coordinates": ["C4"], "identifier": "S2", "x": [12.5], "y": [136]},
    {"coordinates": ["C5"], "identifier": "S1", "x": [6.25], "y": [73]},
    {"coordinates": ["C6"], "identifier": "B", "x": [], "y": [8]},
    {"coordinates": ["D1"], "identifier": "U1", "x": [], "y": [400]},
    {"coordinates": ["D2"], "identifier": "U1", "x": [], "y": [421]},
    {"coordinates": ["D3"], "identifier": "U2", "x": [], "y": [301]},
    {"coordinates": ["D4"], "identifier": "U2", "x": [], "y": [290]},
    {"coordinates": ["D5"], "identifier": "X", "x": [], "y": [9]},
    {"coordinates": ["D6"], "identifier": "B", "x": [], "y": [9]},
]

# Sample data set 2: twelve standards (S1-S12) in duplicate whose x values
# span roughly 1.4e-11 to 2.5e-6, with no blank wells (so there is no blank
# signal for the LLD). Not used by the demo run at the bottom of the file.
data2 = [
    {"coordinates": ["A1"], "identifier": "S12", "x": [0.000002480058], "y": [5823]},
    {"coordinates": ["A2"], "identifier": "S12", "x": [0.000002480058], "y": [6052]},
    {"coordinates": ["A3"], "identifier": "S8", "x": [0.000000030618], "y": [3698]},
    {"coordinates": ["A4"], "identifier": "S8", "x": [0.000000030618], "y": [3469]},
    {"coordinates": ["A5"], "identifier": "S4", "x": [0.000000000378], "y": [2029]},
    {"coordinates": ["A6"], "identifier": "S4", "x": [0.000000000378], "y": [1952]},
    {"coordinates": ["B1"], "identifier": "S11", "x": [0.000000826686], "y": [5554]},
    {"coordinates": ["B2"], "identifier": "S11", "x": [0.000000826686], "y": [6008]},
    {"coordinates": ["B3"], "identifier": "S7", "x": [0.000000010206], "y": [2652]},
    {"coordinates": ["B4"], "identifier": "S7", "x": [0.000000010206], "y": [2681]},
    {"coordinates": ["B5"], "identifier": "S3", "x": [0.000000000126], "y": [1661]},
    {"coordinates": ["B6"], "identifier": "S3", "x": [0.000000000126], "y": [2033]},
    {"coordinates": ["C1"], "identifier": "S10", "x": [0.000000275562], "y": [5302]},
    {"coordinates": ["C2"], "identifier": "S10", "x": [0.000000275562], "y": [5413]},
    {"coordinates": ["C3"], "identifier": "S6", "x": [0.000000003402], "y": [2116]},
    {"coordinates": ["C4"], "identifier": "S6", "x": [0.000000003402], "y": [2356]},
    {"coordinates": ["C5"], "identifier": "S2", "x": [0.000000000042], "y": [1611]},
    {"coordinates": ["C6"], "identifier": "S2", "x": [0.000000000042], "y": [1912]},
    {"coordinates": ["D1"], "identifier": "S9", "x": [0.000000091854], "y": [4629]},
    {"coordinates": ["D2"], "identifier": "S9", "x": [0.000000091854], "y": [4809]},
    {"coordinates": ["D3"], "identifier": "S5", "x": [0.000000001134], "y": [1760]},
    {"coordinates": ["D4"], "identifier": "S5", "x": [0.000000001134], "y": [1712]},
    {"coordinates": ["D5"], "identifier": "S1", "x": [0.000000000014], "y": [2011]},
    {"coordinates": ["D6"], "identifier": "S1", "x": [0.000000000014], "y": [1974]},
]


if __name__ == "__main__":
    # Demo run: fit data set 1 with weighting and a forced zero, print the
    # metrics dict, then display the fitted curve. Guarded so that importing
    # this module does not run the fit, print, or open a plot window.
    model = FourParameterLogisticRegression(data1, weighted=True, forced_zero=True)
    metrics = model.get_metrics()
    print(metrics)

    model.plot_fit()

