import itertools
import numpy as np
import logging
from django.http import JsonResponse
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from lmfit import Model, Parameters
from calculation_methods.py_calculations.calculation_constants.constants import CUBICMODEL, EPSILON, FIVEPLMODEL, FOURPLMODEL, IQR_MAXSCALE, IQR_MINSCALE, LINEARMODEL, LLD_CONSTANT, NUM_POINTS
from calculation_methods.py_calculations.error_handler.error_messages import ErrorMessages
from scipy.interpolate import CubicSpline

# Dataset ``id`` values recognised by extract_data_sets
input_data_id, standard_values_id, blank_values_id = (
    'input_data',
    'standard_values',
    'blank_values',
)

# Shared logger: this module reports its errors through Django's 'django' logger
log = logging.getLogger('django')

def calculate_r_squared(ss_res, ss_tot):
    """
    Return the coefficient of determination computed from two sums of squares.

    R² = 1 - SS_res / SS_tot, where
      - SS_res = Σ(y_true - y_pred)²  (residual sum of squares)
      - SS_tot = Σ(y_true - y_mean)²  (total sum of squares)

    When SS_tot is 0, the y-values have no variance and 0 is returned.
    Otherwise the value is clipped into the range [0, 1].

    Returns:
        float: R-squared in [0, 1].
    """
    if ss_tot != 0:
        # Clip into [0, 1]. The argument order matches min(r, 1) / max(0, ...).
        return max(0, min(1 - (ss_res / ss_tot), 1))
    return 0

def calculate_n_p(y_true, coefficients):
    """
    Return the observation count and the parameter count used by the R² helpers.

    Args:
        y_true: Observed values; n is len(y_true).
        coefficients: Model parameters, or None; p is len(coefficients), or 0 when None.

    Returns:
        tuple: (n, p)
    """
    observation_count = len(y_true)
    if coefficients is None:
        return observation_count, 0
    return observation_count, len(coefficients)

def calculate_adjusted_r_squared(r_squared, n, p):
    """
    Calculate adjusted R-squared to account for the number of predictors.

    Formula: Adjusted R² = 1 - [(1 - R²) * (n - 1) / (n - p - 1)]
    - R²: Regular R-squared
    - n: Number of observations
    - p: Number of predictors

    The result is not clipped, so it can be negative for a poor fit.

    Returns:
        float or None: Adjusted R-squared, or None when r_squared is None or there
        are too few degrees of freedom (n <= p + 1).
    """
    if r_squared is None or n <= p + 1:
        return None  # Cannot compute with insufficient degrees of freedom
    # The guard above ensures n - p - 1 > 0, so no further branch is needed.
    return 1 - ((1 - r_squared) * (n - 1) / (n - p - 1))

def extract_data_sets(data_sets):
    """
    Return the input, standard and blank datasets from a list of dataset dicts.

    Each dataset is read as a mapping with 'id' and 'data' keys. For every target
    id, the 'data' of the first matching dataset is used.

    Args:
        data_sets (iterable of dict): Datasets supplied with the request.

    Returns:
        tuple: (input_data, standard_values, blank_values). blank_values defaults
        to [] when no dataset has id blank_values_id.

    Raises:
        ValueError: via handle_error(ERROR_MISSING_INPUT_DATA) when the input or
        standards dataset is absent or its 'data' is empty.
    """
    def first_data(dataset_id, default=None):
        # next() is given a default, so it cannot raise StopIteration.
        return next(
            (dataset.get('data') for dataset in data_sets if dataset.get('id') == dataset_id),
            default,
        )

    input_data = first_data(input_data_id)
    if not input_data:
        raise handle_error(ErrorMessages.ERROR_MISSING_INPUT_DATA)
    standard_values = first_data(standard_values_id)
    if not standard_values:
        # NOTE(review): missing standards reuse ERROR_MISSING_INPUT_DATA; confirm
        # whether a dedicated error message exists.
        raise handle_error(ErrorMessages.ERROR_MISSING_INPUT_DATA)
    blank_values = first_data(blank_values_id, [])
    return input_data, standard_values, blank_values

def compute_statistical_cv(data):
    """
    Return the coefficient of variation of ``data`` as a percentage.

    CV = (sample standard deviation (ddof=1) / mean) * 100

    A NaN mean or standard deviation is treated as 0. The result is None when
    there are fewer than two values, when the mean is 0 (including a NaN mean),
    or when the standard deviation is not finite.

    Args:
        data (list or np.ndarray): Predicted concentration values.

    Returns:
        float or None: CV in percent, or None as described above.

    Raises:
        ValueError: via handle_error(ERROR_STATISTICAL_CALCULATION) on failure.
    """
    try:
        if len(data) < 2:
            return None  # a sample standard deviation needs two or more values

        mean = np.mean(data)
        std_dev = np.std(data, ddof=1)

        std_dev = 0 if np.isnan(std_dev) else std_dev
        mean = 0 if np.isnan(mean) else mean

        if mean == 0 or not np.isfinite(std_dev):
            return None
        return (std_dev / mean) * 100
    except Exception:
        raise handle_error(ErrorMessages.ERROR_STATISTICAL_CALCULATION)

def handle_error(error_obj, *args, response=None):
    """
    Log the error described by ``error_obj``, then return ``response`` or raise.

    Args:
        error_obj: Mapping with optional "code" and "message" keys. "message" may
            be a str.format template filled from ``args``.
        *args: Positional values substituted into the message template.
        response: Object to return instead of raising.

    Returns:
        ``response``, when one is supplied.

    Raises:
        ValueError: ``ValueError(error_code, message)`` when no response is supplied.
    """
    error_code = error_obj.get("code", "UnknownError")
    message_template = error_obj.get("message", "Unknown error occurred.")
    message = message_template
    if args:
        try:
            message = message_template.format(*args)
        except (IndexError, KeyError, ValueError, AttributeError):
            # A template/argument mismatch must not replace the error being
            # reported with an unrelated formatting exception.
            message = message_template
    # Lazy %-formatting; the logged text is identical to "Error: <message>".
    log.error("Error: %s", message)
    # Explicit None check: the caller's response object is returned even if it
    # happens to be falsy.
    if response is not None:
        return response
    raise ValueError(error_code, message)

def process_input_data(input_data):
    """
    Aggregate raw well entries into per-standard averages, a blank average, and
    flat per-value lists.

    Each entry is read as a mapping with 'identifier', 'y' and 'coordinates' keys
    and an optional 'x' key. Identifiers starting with 'S' (case-insensitive) are
    standards; identifiers starting with 'B' are blanks.

    Args:
        input_data (iterable of dict): Raw well entries.

    Returns:
        dict with keys:
            - 'averages_standards': {identifier: {'avg_x': mean x or None, 'avg_y': mean y}}
            - 'avg_y_blanks': mean of all blank y values, or None if there are no blanks
            - 'x': avg_x of each standard that has x values
            - 'y': avg_y of every standard. NOTE(review): 'y' is not filtered like 'x',
              so the two lists differ in length when a standard lacks x values.
              Confirm that callers expect this.
            - 'blank_y': [avg_y_blanks], or [] when there are no blanks
            - 'all_identifiers', 'all_wells', 'all_x_values', 'all_y_values': one item
              per y value, in input order (x defaults to 0 when absent)

    Raises:
        ValueError: via handle_error(ERROR_PROCESS_VALUES) if processing fails.
    """
    try:
        standards_x = defaultdict(list)
        standards_y = defaultdict(list)
        blank_y_values = []
        all_y_values, all_identifiers, all_wells, all_x_values = [], [], [], []

        for entry in input_data:
            identifier = entry['identifier']
            # `or []` also covers an explicit None. Otherwise a standard with
            # 'x': None would crash on list.extend(None).
            x = entry.get('x') or []
            y = entry['y']
            coordinates = entry['coordinates']
            all_y_values.extend(y)
            all_identifiers.extend([identifier] * len(y))
            # A list of wells is used as-is; a single label is upper-cased and repeated per y value.
            all_wells.extend(coordinates if isinstance(coordinates, list) else [coordinates.upper()] * len(y))
            all_x_values.extend(x if x else [0] * len(y))  # 0 stands in for missing x
            kind = identifier.upper()
            if kind.startswith('S'):
                standards_x[identifier].extend(x)
                standards_y[identifier].extend(y)
            elif kind.startswith('B'):
                blank_y_values.extend(y)

        # Both dicts always receive the same keys, so look y up by key rather
        # than relying on the two dicts having the same iteration order.
        averages_standards = {
            identifier: {
                'avg_x': np.mean(xs) if xs else None,
                'avg_y': np.mean(standards_y[identifier]),
            }
            for identifier, xs in standards_x.items()
        }
        avg_y_blanks = np.mean(blank_y_values) if blank_y_values else None

        return {
            'averages_standards': averages_standards,
            'avg_y_blanks': avg_y_blanks,
            'x': [v['avg_x'] for v in averages_standards.values() if v['avg_x'] is not None],
            'y': [v['avg_y'] for v in averages_standards.values()],
            'blank_y': [avg_y_blanks] if avg_y_blanks is not None else [],
            'all_identifiers': all_identifiers,
            'all_wells': all_wells,
            'all_x_values': all_x_values,
            'all_y_values': all_y_values
        }
    except Exception:
        handle_error(ErrorMessages.ERROR_PROCESS_VALUES)

def process_statistic_data_sets(data_set, request_method, target_source):
    """
    Collapse one data set into ``{target_source: identifier, 'data': [y, ...]}``.

    - 'statistics': ``data_set`` carries an 'id' and a 'data' list of records.
    - Any other method (kinetic_statistics): ``data_set`` is itself a single
      well record, identified by its 'coordinates'.

    Each record's 'y' (an empty list when missing) is flattened with NumPy, and
    the values are concatenated in record order.
    """
    is_statistics = request_method == 'statistics'
    identifier = data_set.get('id') if is_statistics else data_set.get('coordinates')
    records = data_set.get('data', []) if is_statistics else [data_set]

    flat_y_values = [
        value
        for record in records
        for value in np.array(record.get('y', [])).flatten()
    ]

    return {
        target_source: identifier,
        'data': flat_y_values
    }

def validate_points(x, y, weighted=False):
    """
    Drop None entries, convert x and y to NumPy arrays, and optionally compute weights.

    x-values are converted with float(); y-values are kept as given. Nones are
    removed from each sequence independently.

    Args:
        x (list): X-values (e.g., concentrations).
        y (list): Y-values (e.g., responses).
        weighted (bool): If True, compute weights as 1 / (y² + 1e-10).

    Returns:
        tuple: (x_values, y_values, weights); weights is None when weighted is False.

    Raises:
        ValueError: via handle_error(ERROR_CONVERTING_FLOAT) when an x-value
        cannot be converted to float.
    """
    try:
        x_values = np.array([float(value) for value in x if value is not None])
        y_values = np.array([value for value in y if value is not None])
        # Inverse-square weighting; the 1e-10 term keeps zero responses finite.
        weights = 1 / (y_values ** 2 + 1e-10) if weighted else None
        return x_values, y_values, weights
    except ValueError:
        handle_error(ErrorMessages.ERROR_CONVERTING_FLOAT)

def calculate_curve_y(x, model_type, params):
    """
    Evaluate the fitted standard curve at ``x``.

    Formulas:
    - Linear: y = a + b * x (params: sequence [a, b])
    - Cubic:  y = a + b*dx + c*dx² + d*dx³, with dx = x - x_start, using the
              segment whose [x_start, x_end] contains x (params: list of dicts with
              keys 'x_start', 'x_end', 'a', 'b', 'c', 'd')
    - 4PL:    y = d + (a - d) / (1 + (|x|/c)^b)      (params: dict a, b, c, d)
    - 5PL:    y = d + (a - d) / (1 + (x/c)^b)^g      (params: dict a, b, c, d, g)

    The 4PL/5PL arithmetic goes through NumPy, so ``x`` may also be an array there
    (the lmfit model functions in this module pass arrays).

    Args:
        x: Input x-value.
        model_type (str): LINEARMODEL, CUBICMODEL, FOURPLMODEL or FIVEPLMODEL.
        params (dict or list): Model parameters in the shape listed above.

    Returns:
        The y-value(s). Returns None for a linear model with x=None, or for a cubic
        x that lies outside every segment.

    Raises:
        ValueError: via handle_error. ERROR_INCORRECT_MODEL for an unknown
        model_type; ERROR_CALCULATE_CURVE_Y when evaluation fails.
    """
    if model_type not in (LINEARMODEL, CUBICMODEL, FOURPLMODEL, FIVEPLMODEL):
        # Checked outside the try so the specific error is not replaced by the
        # generic ERROR_CALCULATE_CURVE_Y raised from the handler below.
        handle_error(ErrorMessages.ERROR_INCORRECT_MODEL)
        return None
    try:
        if model_type == LINEARMODEL:
            a, b = params[0], params[1]
            return a + (b * x) if x is not None else None
        if model_type == CUBICMODEL:
            for segment in params:
                if segment['x_start'] <= x <= segment['x_end']:
                    a_i, b_i, c_i, d_i = segment['a'], segment['b'], segment['c'], segment['d']
                    dx = x - segment['x_start']
                    return a_i + (b_i * dx) + (c_i * dx ** 2) + (d_i * dx ** 3)
            return None  # x outside every defined segment
        if model_type == FOURPLMODEL:
            a, b, c, d = params['a'], params['b'], params['c'], params['d']
            # Clamp c to >= 0 before adding EPSILON so the divisor is always > 0.
            # (np.maximum(c + EPSILON, 0) returned 0 for c <= -EPSILON.)
            # For c >= 0 this is the same c + EPSILON as before.
            c_safe = np.maximum(c, 0) + EPSILON
            return d + ((a - d) / (1 + (np.abs(x) / c_safe) ** b))
        # FIVEPLMODEL
        a, b, c, d, g = params['a'], params['b'], params['c'], params['d'], params['g']
        x = np.maximum(x, EPSILON)  # keep x > 0 for the power term
        return d + ((a - d) / (1 + (x / (c + EPSILON)) ** b) ** g)
    except Exception:
        handle_error(ErrorMessages.ERROR_CALCULATE_CURVE_Y)
        return None

def generate_table_by_sample_groups(**sgargs):
    """
    Build one table row per sample identifier, with blank-corrected and predicted values.

    Keyword Args:
        input_data: Iterable of well dicts with 'identifier', 'coordinates' and 'y'
            keys (only the first 'y' value of each well is used) and an optional 'x' key.
        model_type: Curve model type, forwarded to calculate_curve_y for standards.
        params: Model parameters, forwarded to calculate_curve_y.
        predict_func: Callable that maps a y-value to a predicted x-value (may return None).
        cv_func: Callable that maps an identifier to its CV.

    Returns:
        list[dict]: One dict per identifier, in first-seen order, with the keys
        'identifier', 'coordinates', 'y', 'y_corrected', 'x', 'x_predicted',
        'x_difference', 'cv' and 'curve_y'. 'curve_y' holds one curve value for
        identifiers starting with 'S' and is empty otherwise.

    Raises:
        ValueError: via handle_error(ERROR_GENERATE_TABLE) if any step fails.
    """
    try:
        input_data = sgargs.get('input_data')
        model_type = sgargs.get('model_type')
        params = sgargs.get('params')
        predict_func = sgargs.get('predict_func')
        cv_func = sgargs.get('cv_func')

        grouped_data = defaultdict(list)
        blank_y = []
        for item in input_data:
            grouped_data[item['identifier']].append(item)
            # Collect blank wells across ALL groups. Every item in a group shares
            # one identifier, so a per-group filter only ever matched blank groups,
            # and samples and standards were never blank-corrected.
            if item['identifier'].upper().startswith('B'):
                blank_y.extend(np.ravel(item['y']))
        mean_blank_y = np.mean(blank_y) if blank_y else 0

        def process_group(entry):
            identifier = entry[0]['identifier']
            coordinates = [item['coordinates'] for item in entry]
            # .get(): a missing 'x' key falls back to [0], like a null or empty x.
            x_values = entry[0].get('x') or [0]
            y_values = [item['y'][0] for item in entry]
            y_corrected = [y - mean_blank_y for y in y_values]  # subtract blank mean
            x_predicted = [predict_func(y) for y in y_values]   # predict x from y
            has_x = len(x_values) > 0 and x_values[0] is not None
            x_difference = [
                x_values[0] - x_pred if has_x and x_pred is not None else None
                for x_pred in x_predicted
            ]
            cv = cv_func(identifier)
            curve_y = []
            if identifier.upper().startswith('S'):  # only standards get a curve value
                curve_y.append(calculate_curve_y(x_values[0], model_type, params))

            return {
                'identifier': identifier,
                'coordinates': coordinates,
                'y': y_values,
                'y_corrected': y_corrected,
                'x': x_values,
                'x_predicted': x_predicted,
                'x_difference': x_difference,
                'cv': cv,
                'curve_y': curve_y
            }

        # executor.map preserves the order of grouped_data (first-seen identifiers).
        with ThreadPoolExecutor() as executor:
            return list(executor.map(process_group, grouped_data.values()))
    except Exception:
        handle_error(ErrorMessages.ERROR_GENERATE_TABLE)

def calculate_statistics(**statargs):
    """
    Calculate goodness-of-fit metrics for observed vs. predicted y-values.

    Formulas:
    - RSS = Σ(y_true - y_pred)²
    - MSE = RSS / n
    - SYX = √(RSS / (n - 2))
    - R² = 1 - (RSS / SS_tot), clipped to [0, 1] by calculate_r_squared
    - Adjusted R² = 1 - [(1 - R²) * (n - 1) / (n - p - 1)], with p fixed at 1

    Keyword Args:
        y_true: Observed y-values (list or array).
        y_pred: Predicted y-values, aligned with y_true.
        coefficients: Model parameters (only passed through to calculate_n_p).

    Returns:
        dict: 'RSS', 'MSE' (None if n == 0), 'SS_Total', 'R_Squared',
        'Adjusted_R_Squared' (None if n <= 2) and 'SYX' (None if n <= 2).

    Raises:
        ValueError: via handle_error(ERROR_STATISTICAL_CALCULATION) on failure.
    """
    try:
        # asarray makes plain lists work too; list - list is not element-wise.
        y_true = np.asarray(statargs.get('y_true'))
        y_pred = np.asarray(statargs.get('y_pred'))
        coefficients = statargs.get('coefficients')

        ss_res = np.sum((y_true - y_pred) ** 2)  # Residual sum of squares
        ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)  # Total sum of squares
        r_squared = calculate_r_squared(ss_res, ss_tot)
        n, _ = calculate_n_p(y_true, coefficients)
        # NOTE(review): the predictor count is fixed at 1 (a single x variable)
        # rather than p = len(coefficients) from calculate_n_p. Confirm which the
        # report is meant to use.
        adjusted_r_squared = calculate_adjusted_r_squared(r_squared, n, 1)
        mse = ss_res / n if n > 0 else None  # Mean squared error
        syx = np.sqrt(ss_res / (n - 2)) if n > 2 else None  # Standard error of estimate

        return {
            'RSS': ss_res,
            'MSE': mse,
            'SS_Total': ss_tot,
            'R_Squared': r_squared,
            'Adjusted_R_Squared': adjusted_r_squared,
            'SYX': syx
        }
    except Exception:
        return handle_error(ErrorMessages.ERROR_STATISTICAL_CALCULATION)

def calculate_lld(**lldargs):
    """
    Calculate the lower limit of detection (LLD) from blank responses and the fitted curve.

    curvemin and curvemax are the curve evaluated at min(x_values) and max(x_values):
        threshold = mean_blank + 2 * LLD_CONSTANT * std_blank   if curvemax > curvemin
        threshold = mean_blank - 2 * LLD_CONSTANT * std_blank   otherwise
        LLD = (threshold - curvemin) / (curvemax - curvemin)
    std_blank is the population standard deviation (np.std with its default
    ddof=0), or 0 when there is a single blank.

    NOTE(review): the result is a fraction of the curve's min_x..max_x span and is
    not rescaled into x units. Confirm that callers expect this.

    Keyword Args:
        blank_y: Sequence of blank response values (list or NumPy array).
        model_type: Curve model type (lower-cased before use).
        params: Model parameters for calculate_curve_y.
        x_values: X-values that define the curve range.

    Returns:
        float or None: The LLD, or None when there are no blanks or the curve has
        the same value at both ends of the range.

    Raises:
        ValueError: via handle_error(ERROR_CALCULATE_LLD) if any step fails.
    """
    try:
        blank_y = lldargs.get('blank_y')
        model_type = lldargs.get('model_type').lower()
        params = lldargs.get('params')
        x_values = lldargs.get('x_values')

        # Explicit None/len check: `not blank_y` raises for NumPy arrays with more
        # than one element.
        if blank_y is None or len(blank_y) == 0:
            return None  # No blank data, LLD cannot be computed
        mean_blank = np.mean(blank_y)
        std_blank = np.std(blank_y) if len(blank_y) > 1 else 0  # 0 for a single blank
        min_x, max_x = np.min(x_values), np.max(x_values)
        curvemin = calculate_curve_y(min_x, model_type, params)  # y at min x
        curvemax = calculate_curve_y(max_x, model_type, params)  # y at max x
        curve_diff = curvemax - curvemin
        if curve_diff == 0:
            return None  # Avoid division by zero
        # Move the threshold away from the blank mean in the curve's rising direction.
        spread = 2 * LLD_CONSTANT * std_blank
        threshold = mean_blank + spread if curvemax > curvemin else mean_blank - spread
        return (threshold - curvemin) / curve_diff
    except Exception:
        handle_error(ErrorMessages.ERROR_CALCULATE_LLD)

def calculate_params_with_lmfit(**lmfitargs):
    """
    Fit 4PL or 5PL model parameters with lmfit's non-linear least squares.

    Formulas:
    - 4PL: y = d + (a - d) / (1 + (|x|/c)^b)
    - 5PL: y = d + (a - d) / (1 + (x/c)^b)^g

    Keyword Args:
        x_values: X-values (e.g., concentrations).
        y_values: Y-values (e.g., responses), aligned with x_values.
        model_type: FOURPLMODEL or FIVEPLMODEL (compared after lower-casing).
        weighted: If True, pass ``weights`` to the fit.
        forced_to_zero: If True, fix a = 0 during the fit.
        weights: Optional per-point weights, used only when weighted is True.

    Returns:
        list: Fitted [a, b, c, d] for 4PL, or [a, b, c, d, g] for 5PL.

    Raises:
        ValueError: via handle_error. ERROR_INCORRECT_MODEL, ERROR_INVALID_DATA or
        ERROR_INSUFFICIENT_DATA_POINTS for bad input; ERROR_CALCULATE_PARAMS if the
        fit itself fails.
    """
    x_values = lmfitargs.get('x_values')
    y_values = lmfitargs.get('y_values')
    model_type = str(lmfitargs.get('model_type') or '').lower()
    weighted = lmfitargs.get('weighted', False)
    forced_to_zero = lmfitargs.get('forced_to_zero', False)
    weights = lmfitargs.get('weights', None)

    # Input validation runs before the broad try below. That way its specific
    # error codes reach the caller instead of being replaced by ERROR_CALCULATE_PARAMS.
    if model_type not in (FOURPLMODEL, FIVEPLMODEL):
        handle_error(ErrorMessages.ERROR_INCORRECT_MODEL)
    try:
        x_arr = np.asarray(x_values, dtype=float)
        y_arr = np.asarray(y_values, dtype=float)
    except (TypeError, ValueError):
        handle_error(ErrorMessages.ERROR_INVALID_DATA)
    if not (np.all(np.isfinite(x_arr)) and np.all(np.isfinite(y_arr))):
        handle_error(ErrorMessages.ERROR_INVALID_DATA)
    # Need at least as many points as free parameters: 4 for 4PL, 5 for 5PL.
    min_points = 4 if model_type == FOURPLMODEL else 5
    if x_arr.size < min_points:
        handle_error(ErrorMessages.ERROR_INSUFFICIENT_DATA_POINTS)

    try:
        x_fit = np.maximum(x_arr, EPSILON)  # keep x strictly positive for the power terms
        y_fit = y_arr

        # NOTE(review): direction is judged from the first and last points as given,
        # which assumes the input is ordered by x.
        is_increasing = y_fit[-1] > y_fit[0]

        if model_type == FOURPLMODEL:
            def model_func(x, a, b, c, d):
                return calculate_curve_y(x, FOURPLMODEL, {'a': a, 'b': b, 'c': c, 'd': d})
        else:
            def model_func(x, a, b, c, d, g):
                return calculate_curve_y(x, FIVEPLMODEL, {'a': a, 'b': b, 'c': c, 'd': d, 'g': g})
        model = Model(model_func)

        low, high = np.min(y_fit), np.max(y_fit)
        params = Parameters()
        # With b > 0, a is the response as x -> 0 and d is the response as
        # x -> infinity, so assigning min/max to a/d sets the curve direction.
        params.add('a', value=0 if forced_to_zero else (low if is_increasing else high),
                   min=0 if forced_to_zero else -np.inf, vary=not forced_to_zero)
        # b starts at +1 for both directions. The a/d choice above already encodes
        # the direction; a negative b would mirror a decreasing guess back into an
        # increasing one.
        params.add('b', value=1.0, min=-10, max=10)
        params.add('c', value=np.median(x_fit), min=EPSILON)
        params.add('d', value=high if is_increasing else low, min=-np.inf)
        if model_type == FIVEPLMODEL:
            params.add('g', value=1.0, min=0.1, max=10)  # asymmetry factor, 5PL only

        result = model.fit(y_fit, params, x=x_fit,
                           weights=weights if weighted else None, nan_policy='omit')

        fitted_params = result.params.valuesdict()
        names = ['a', 'b', 'c', 'd'] + (['g'] if model_type == FIVEPLMODEL else [])
        return [fitted_params[name] for name in names]
    except Exception:
        handle_error(ErrorMessages.ERROR_CALCULATE_PARAMS)
def generate_interpolated_curve_data(**interpolation_args):
    """
    Generate a dense set of (x, y) points along a standard curve for plotting.

    The x-range starts below the smallest standard, at min_x minus
    IQR_MINSCALE times the smallest x gap, and never goes below 0. It ends above
    the largest standard, at max_x plus the larger of IQR_MAXSCALE times the
    smallest x gap and IQR_MAXSCALE times the x interquartile range.

    Keyword Args:
        x_data: X-values of the standards.
        y_data: Y-values of the standards, aligned with x_data.
        curve_type: One of CUBICMODEL, FOURPLMODEL, FIVEPLMODEL or LINEARMODEL
            (compared after lower-casing).
        coefficients: Model parameters for calculate_curve_y (not used for cubic).

    Returns:
        list[dict]: {'x_coordinate': float, 'y_coordinate': float} for each point,
        sorted by x, with negative y-values clamped to 0.

    Raises:
        ValueError: via handle_error. ERROR_INCORRECT_MODEL for an unsupported
        curve_type; ERROR_GENERATE_CURVE_POINTS if generation fails.
    """
    curve_type = str(interpolation_args.get('curve_type') or '').lower()
    # Validated before the try below so the specific ERROR_INCORRECT_MODEL reaches
    # the caller instead of being replaced by ERROR_GENERATE_CURVE_POINTS.
    if curve_type not in (CUBICMODEL, FOURPLMODEL, FIVEPLMODEL, LINEARMODEL):
        handle_error(ErrorMessages.ERROR_INCORRECT_MODEL)
        return []
    try:
        x_data = interpolation_args.get('x_data')
        y_data = interpolation_args.get('y_data')
        coefficients = interpolation_args.get('coefficients', None)

        # Sort the standards by x, keeping y aligned.
        sorted_x, sorted_y = np.sort(x_data), np.array(y_data)[np.argsort(x_data)]

        # Extend the plotted range beyond the standards.
        q1, q3 = np.percentile(sorted_x, [25, 75])
        iqr = q3 - q1
        step_size = np.diff(sorted_x).min()  # smallest gap between neighbouring x-values
        lower_ext = max(0, sorted_x[0] - step_size * IQR_MINSCALE)  # never below 0
        upper_ext = sorted_x[-1] + max(step_size * IQR_MAXSCALE, iqr * IQR_MAXSCALE)

        # Evenly spaced points, plus the exact smallest and largest standard x-values.
        total_data_points = max(NUM_POINTS, len(sorted_x))
        interpolated_x = np.sort(np.append(
            np.linspace(lower_ext, upper_ext, total_data_points),
            [min(sorted_x), max(sorted_x)],
        ))

        if curve_type == CUBICMODEL:
            # Natural cubic spline through the standards, extrapolated over the extended range.
            y_interpolated = CubicSpline(sorted_x, sorted_y, bc_type='natural',
                                         extrapolate=True)(interpolated_x)
        else:
            # 4PL / 5PL / linear: evaluate the fitted model equation.
            y_interpolated = np.array([calculate_curve_y(x, curve_type, coefficients)
                                       for x in interpolated_x])

        return [{'x_coordinate': float(x), 'y_coordinate': float(max(y, 0))}
                for x, y in zip(interpolated_x, y_interpolated)]
    except Exception:
        handle_error(ErrorMessages.ERROR_GENERATE_CURVE_POINTS)
        return []
