import numpy as np
from lmfit import Model, Parameters
from collections import defaultdict
import concurrent.futures
import pandas as pd

class FiveParameterLogisticRegression:
    lld_constant = 1.645

    def __init__(self, data_sets, weighted=False, forced_to_zero=False):
        """Load the data sets, fit the curve and cache the fitted coefficients.

        Args:
            data_sets: List of ``{'id': ..., 'data': ...}`` dictionaries. The
                entries whose ``id`` is ``'input_data'``, ``'standard_values'``
                and ``'blank_values'`` are required.
            weighted: Passed on to the fit; when true the fit uses weights
                derived from the responses.
            forced_to_zero: When true, parameter ``a`` is fixed at 0.

        Raises:
            ValueError: If ``data_sets`` is not a list of dictionaries, a
                required key or data set is missing, or the fit produced no
                parameters.
        """
        # Validate the container shape up front so that malformed items are
        # reported as ValueError instead of leaking a TypeError from indexing.
        if not isinstance(data_sets, list) or not all(isinstance(item, dict) for item in data_sets):
            raise ValueError("The 'data_sets' field should be a list of dictionaries.")

        try:
            self.input_data = next(item['data'] for item in data_sets if item['id'] == 'input_data')
            self.standard_values = next(item['data'] for item in data_sets if item['id'] == 'standard_values')
            self.blank_values = next(item['data'] for item in data_sets if item['id'] == 'blank_values')
        except KeyError as e:
            raise ValueError(f"Missing expected key: {e}") from e
        except StopIteration:
            raise ValueError(
                "Missing one of the expected data sets ('input_data', 'standard_values', or 'blank_values')."
            ) from None

        self.weighted = weighted
        self.forced_to_zero = forced_to_zero
        self.process_and_get_all_values()
        self.valid_points()
        self.params, self.r_squared = self.calculate_params_with_lmfit()
        if self.params is None:
            raise ValueError("Failed to calculate parameters with lmfit.")

        # Expose the coefficients as attributes right away; other methods read
        # self.a .. self.g and should not depend on calculate_cv having run.
        a, b, c, d, g = self.params if not self.forced_to_zero else (0, *self.params[1:])
        self.a, self.b, self.c, self.d, self.g = a, b, c, d, g

        self.coefficients = self.extract_coefficients()
        
    def process_and_get_all_values(self):
        """Flatten ``self.input_data`` and average the standard and blank readings.

        Identifiers are upper-cased. Those starting with ``'S'`` are collected
        as standards and those starting with ``'B'`` as blanks.

        Sets the following attributes:
            averages_standards: ``{identifier: {'avg_x': ..., 'avg_y': ...}}``;
                ``avg_x`` is ``None`` for a standard without x values.
            avg_y_blanks: Mean blank reading, or ``None`` when there are none.
            x, y: Averaged standard points. Only standards that have x values
                are included, so both lists have the same length and stay
                paired index by index.
            blank_y: ``[avg_y_blanks]``, or ``[]`` when there are no blanks.
            all_identifiers, all_wells, all_x_values, all_y_values: One item
                per individual reading, in input order.
        """
        standards = {}  # identifier -> {'x': [...], 'y': [...]}, in first-seen order
        blank_readings = []

        all_y_values, all_identifiers, all_wells, all_x_values = [], [], [], []

        for entry in self.input_data:
            identifier = entry['identifier'].upper()
            # A missing or None 'x' is treated like an empty list.
            x = entry.get('x') or []
            y = entry['y']
            coordinates = entry['coordinates']

            all_y_values.extend(y)
            all_identifiers.extend([identifier] * len(y))
            all_wells.extend(coordinates if isinstance(coordinates, list) else [coordinates.upper()] * len(y))
            all_x_values.extend(x if x else [0] * len(y))

            if identifier.startswith('S'):
                readings = standards.setdefault(identifier, {'x': [], 'y': []})
                readings['x'].extend(x)
                readings['y'].extend(y)
            elif identifier.startswith('B'):
                blank_readings.extend(y)

        self.averages_standards = {
            identifier: {
                'avg_x': np.mean(readings['x']) if readings['x'] else None,
                'avg_y': np.mean(readings['y']),
            }
            for identifier, readings in standards.items()
        }
        self.avg_y_blanks = np.mean(blank_readings) if blank_readings else None

        # Drop x and y together: filtering only x would shift every later
        # response onto the wrong concentration.
        usable = [avg for avg in self.averages_standards.values() if avg['avg_x'] is not None]
        self.x = [avg['avg_x'] for avg in usable]
        self.y = [avg['avg_y'] for avg in usable]
        self.blank_y = [self.avg_y_blanks] if self.avg_y_blanks is not None else []

        self.all_identifiers = all_identifiers
        self.all_wells = all_wells
        self.all_x_values = all_x_values
        self.all_y_values = all_y_values

    def valid_points(self):
        """Build the paired ``x_values`` / ``y_values`` arrays used for fitting.

        A point is kept only when both its x and its y are truthy (not
        ``None`` and not zero). The two lists are filtered together so the
        resulting arrays always have equal length and matching positions.

        Raises:
            ValueError: If ``self.x`` and ``self.y`` differ in length or a
                value cannot be converted to ``float``.
        """
        if len(self.x) != len(self.y):
            raise ValueError(
                f"Error converting points to float: got {len(self.x)} x values for {len(self.y)} y values"
            )
        try:
            points = [(float(x), float(y)) for x, y in zip(self.x, self.y) if x and y]
        except (TypeError, ValueError) as e:
            raise ValueError(f"Error converting points to float: {e}") from e

        self.x_values = np.array([x for x, _ in points], dtype=float)
        self.y_values = np.array([y for _, y in points], dtype=float)

    def five_param_logistic(self, x, a, b, c, d, g):
        return d + ((a - d) / ((1 + (x / c) ** b) ** g))

    def calculate_params_with_lmfit(self):
        min_x = np.min(self.x_values)
        max_x = np.max(self.x_values)
        scale_factor = max_x - min_x
        x_scaled = (self.x_values - min_x) / scale_factor

        model = Model(self.five_param_logistic)

        params = Parameters()
        params.add('a', value=np.max(self.y_values), min=0)
        params.add('b', value=1, min=0)
        params.add('c', value=np.median(x_scaled), min=0)
        params.add('d', value=np.min(self.y_values), min=0)
        params.add('g', value=1, min=0)  # Add the fifth parameter g

        if self.forced_to_zero:
            params['a'].set(value=0, vary=False)

        weights = None
        if self.weighted:
            weights = 1 / (self.y_values ** 2 + 1e-10)

        result = model.fit(self.y_values, params, x=x_scaled, weights=weights)

        best_params = result.params.valuesdict()
        best_params['c'] = best_params['c'] * scale_factor + min_x

        r_squared = self.calculate_r_squared(list(best_params.values()), x_scaled)

        return list(best_params.values()), r_squared
    
    def extract_coefficients(self):
        try:
            # Ensure params is correctly set and has valid values
            if not hasattr(self, 'params') or self.params is None:
                raise ValueError("Params attribute is not set or is None.")
            if not isinstance(self.params, list):
                raise ValueError("Params should be a list.")
            if len(self.params) < 5:
                raise ValueError("Insufficient parameters to extract coefficients.")

            # Extract coefficients
            def extract_single_coefficient(i):
                a, b, c, d, g = self.params if not self.forced_to_zero else (0, *self.params[1:])
                coeff = {
                    'a': a,
                    'b': b,
                    'c': c,
                    'd': d,
                    'g': g,  # Include the fifth parameter g
                    'x_start': self.x_values[i],
                    'x_end': self.x_values[i + 1],
                }
                identifier = next(
                    (entry['identifier'] for entry in self.standard_values if entry['x'] == [coeff['x_end']]), 
                    None
                )
                coeff['identifier'] = identifier
                return coeff

            # Use ThreadPoolExecutor to parallelize coefficient extraction
            with concurrent.futures.ThreadPoolExecutor() as executor:
                coefficients_table = list(executor.map(extract_single_coefficient, range(len(self.x_values) - 1)))

            return coefficients_table

        except Exception as e:
            raise ValueError(f"Error in extract_coefficients: {e}")

    def calculate_r_squared(self, params, x_scaled):
        try:
            x_unique = np.unique(self.x_values)
            y_avg = np.array([np.mean(self.y_values[self.x_values == xi]) for xi in x_unique])
            y_pred = self.five_param_logistic(x_scaled, *params)
            rss = np.sum((y_avg - y_pred) ** 2)
            tss = np.sum((y_avg - np.mean(y_avg)) ** 2)
            r_squared = 1 - (rss / tss)
            return r_squared
        except Exception as e:
            raise ValueError(f"Error in calculate_r_squared: {e}")

    def predict(self, x_values):
        try:
            a, b, c, d, g = self.params if not self.forced_to_zero else (0, *self.params[1:])
            return self.five_param_logistic(x_values, a, b, c, d, g)
        except Exception as e:
            raise ValueError(f"Error in predict_coefficients: {e}")

    def calculate_cv(self, sample_type):
        """Return the percent coefficient of variation of back-calculated x.

        All entries of ``self.input_data`` whose identifier starts with
        ``sample_type`` are used (prefix match, so ``'S1'`` also selects
        ``'S10'``). Each y reading is converted with
        ``calculate_concentration``; readings without a valid result are
        skipped.

        Side effect: stores the curve parameters as ``self.a`` .. ``self.g``.

        Returns:
            ``std / mean * 100`` with the sample standard deviation (ddof=1),
            ``0`` when the mean is zero, or ``None`` when there are no
            parameters, no matching entries, or fewer than two valid values.

        Raises:
            ValueError: If the computation fails.
        """
        try:
            if self.params is None:
                return None

            a, b, c, d, g = self.params if not self.forced_to_zero else (0, *self.params[1:])
            self.a, self.b, self.c, self.d, self.g = a, b, c, d, g

            data = pd.DataFrame(self.input_data)
            if data.empty or 'identifier' not in data.columns:
                return None
            sample_data = data[data['identifier'].str.startswith(sample_type)]

            if sample_data.empty:
                return None

            y_values = np.concatenate(sample_data['y'].values)

            # Filter the failures out before building the array so it is a
            # plain float array rather than an object array containing None.
            concentrations = np.array(
                [value for value in map(self.calculate_concentration, y_values) if value is not None],
                dtype=float,
            )

            if len(concentrations) < 2:
                return None

            mean = np.mean(concentrations)
            std_dev = np.std(concentrations, ddof=1)
            return (std_dev / mean) * 100 if mean != 0 else 0
        except Exception as e:
            raise ValueError(f"Error calculating CV: {e}") from e
        
    def calculate_concentration(self, y):
        try:
            if y != self.d and y - self.d != 0:
                concentration_predicted = self.c * (((self.a - self.d) / (y - self.d)) ** (1 / self.g) - 1) ** (1 / self.b)
                if concentration_predicted >= 0 and not np.isinf(concentration_predicted) and not np.isnan(concentration_predicted):
                    return concentration_predicted
        except (OverflowError, ZeroDivisionError, ValueError):
            pass
        return None

    def calculate_statistics(self):
        try:
            y_true = self.y_values
            y_pred = self.five_param_logistic(self.x_values, *self.params)
            ss_res = np.sum((y_true - y_pred) ** 2)
            ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
            r_squared = 1 - (ss_res / ss_tot) if ss_tot != 0 else None
            n = len(y_true)
            p = len(self.coefficients)
            adjusted_r_squared = 1 - ((1 - r_squared) * ((n - 1) / (n - 2))) if n > 2 else r_squared
            #adjusted_r_squared = 1 - ((1 - r_squared) * (n - 1) / (n - p - 1)) if r_squared is not None and n > p + 1 else None
            mse = ss_res / n if n > 0 else None
            syx = np.sqrt(ss_res / (n - 2)) if n > 2 else None
            return {
                "RSS": ss_res,
                "MSE": mse if mse is not None else "NA",
                "SS_Total": ss_tot,
                "R_squared": r_squared if r_squared is not None else "NA",
                "Adjusted_R_squared": adjusted_r_squared if adjusted_r_squared is not None else "NA",
                "SYX": syx if syx is not None else "NA"
            }
        except Exception as e:
            return {
                "RSS": "Error",
                "MSE": "Error",
                "SS_Total": "Error",
                "R_squared": "Error",
                "SYX": "Error"
            }

    def get_lld(self):
        try:
            self.blank_y = [blank_entry['y'] for blank_entry in self.blank_values if blank_entry['identifier'].startswith('B')]
            if not self.blank_y:
                return None
            mean_blank = np.mean(self.blank_y)
            std_blank = np.std(self.blank_y, ddof=1) if len(self.blank_y) > 1 else 0
            min_concentration, max_concentration = np.min(self.x_values), np.max(self.x_values)
            curvemin = self.predict(min_concentration)
            curvemax = self.predict(max_concentration)

            if curvemax > curvemin:
                lld = (mean_blank + (2 * self.lld_constant * std_blank) - curvemin) / (curvemax - curvemin)
            else:
                lld = (mean_blank - (2 * self.lld_constant * std_blank) - curvemin) / (curvemax - curvemin)
            return lld
        except Exception as e:
            raise ValueError(f"Error calculating LLD: {e}")

    def generate_additional_table_by_sample_groups(self):
        """Return one metrics row per distinct identifier in ``self.input_data``.

        Entries are grouped by their ``identifier`` (groups keep first-seen
        order) and each group is passed to ``group_and_calculate_metrics``.

        Raises:
            ValueError: If grouping or the per-group calculation fails.
        """
        try:
            groups = {}
            for record in self.input_data:
                groups.setdefault(record['identifier'], []).append(record)

            return [self.group_and_calculate_metrics(members) for members in groups.values()]
        except Exception as e:
            raise ValueError(f"Error in generate_additional_table_by_sample_groups: {e}")

    def group_and_calculate_metrics(self, entry):
        """Summarise one group of input entries that share an identifier.

        Args:
            entry: List of input dictionaries with the keys ``identifier``,
                ``coordinates``, ``x`` and ``y``.

        Returns:
            dict with
                identifier: Identifier of the last item in the group.
                coordinates: The ``coordinates`` value of every item.
                y: Every y reading of every item.
                y_corrected: ``y`` minus the mean of the blank readings
                    (identifier starting with ``'B'``; 0 without blanks).
                x: The ``x`` list of the last item in the group.
                x_predicted: Mean back-calculated concentration of the
                    readings that could be inverted, or ``None`` if none could.
                x_difference: ``x[0] - x_predicted`` when ``x`` has exactly
                    one value and ``x_predicted`` is available, else ``None``.
                cv: Result of ``calculate_cv(identifier)``.

        Raises:
            ValueError: If the calculation fails.
        """
        try:
            identifier = ''
            coordinates = []
            x_values = []
            y_values = []

            # Flatten the blank readings so entries with several replicates
            # (or different replicate counts) are averaged correctly.
            blank_readings = [
                float(value)
                for blank_entry in self.blank_values
                if blank_entry['identifier'].startswith('B')
                for value in np.ravel(blank_entry['y'])
            ]
            mean_blank_y = np.mean(blank_readings) if blank_readings else 0

            for item in entry:
                identifier = item['identifier']
                coordinates.append(item['coordinates'])
                x_values = item['x']
                # Keep every replicate, not just the first reading of an item.
                y_values.extend(item['y'])
            y_corrected = [y - mean_blank_y for y in y_values]

            # calculate_cv must run first: it stores the curve parameters on
            # the instance, and calculate_concentration reads them.
            # NOTE(review): calculate_cv matches identifiers by prefix, so the
            # CV reported for 'S1' also includes 'S10', 'S11', ... if present.
            cv = self.calculate_cv(identifier)

            concentrations = [
                value for value in map(self.calculate_concentration, y_values) if value is not None
            ]
            # Report "no usable reading" as None instead of the NaN that the
            # mean of an empty list would give.
            concentration_predicted = np.mean(concentrations) if concentrations else None

            if len(x_values) == 1 and concentration_predicted is not None:
                concentration_difference = x_values[0] - concentration_predicted
            else:
                concentration_difference = None

            return {
                "identifier": identifier,
                "coordinates": coordinates,
                "y": y_values,
                "y_corrected": y_corrected,
                "x": x_values,
                "x_predicted": concentration_predicted,
                "x_difference": concentration_difference,
                "cv": cv
            }
        except Exception as e:
            raise ValueError(f"Error in group_and_calculate_metrics: {e}") from e

    def get_metrics(self):
        try:
            method_name = "5-Parameter Logistic Regression"
            additional_table_details = self.generate_additional_table_by_sample_groups()
            metrics = self.calculate_statistics()
            return {
                "method": method_name,
                "function": {"a": self.a, "b": self.b,"c": self.c, "d":self.d, "g":self.g},
                "metrics": metrics,
                "Additional_Table_Details": additional_table_details
            }
        except Exception as e:
            raise ValueError(f"Error in get_coefficients: {e}")

# Demo plate: rows A-C each hold the same five-point standard series
# (S5 .. S1 in columns 1-5), followed by standard and blank readings that
# carry no x values.
sample_data1 = [
    {
        "id": "input_data",
        "data": [
            {"coordinates": [f"{row}{column}"], "identifier": f"S{6 - column}", "x": [dose], "y": [signal]}
            for row, signals in (
                ("A", (640, 426, 256, 130, 70)),
                ("B", (664, 400, 240, 134, 76)),
                ("C", (635, 410, 244, 136, 73)),
            )
            for column, (dose, signal) in enumerate(zip((100, 50, 25, 12.5, 6.25), signals), start=1)
        ],
    },
    {
        "id": "standard_values",
        "data": [
            {"coordinates": [well], "identifier": name, "x": [], "y": [signal]}
            for well, name, signal in (
                ("D1", "S1", 400),
                ("D2", "S1", 421),
                ("D3", "S2", 301),
                ("D4", "S2", 290),
            )
        ],
    },
    {
        "id": "blank_values",
        "data": [
            {"coordinates": [well], "identifier": name, "x": [], "y": [signal]}
            for well, name, signal in (
                ("A6", "B", 8),
                ("B6", "B", 10),
                ("C6", "B", 8),
                ("D5", "X", 9),
                ("D6", "B", 9),
            )
        ],
    },
]


# Demo run: fit the sample data with weighting enabled and `a` left free,
# then print the resulting metrics report.
# NOTE(review): these statements sit at module top level, so the fit and the
# print also run whenever this module is imported; consider an
# `if __name__ == "__main__":` guard if the class is meant to be imported.
model = FiveParameterLogisticRegression(data_sets=sample_data1, weighted=True, forced_to_zero=False)
metrics = model.get_metrics()
print(metrics)
