import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import CubicSpline
from tabulate import tabulate

class CubicSplineInterpolator:
    """Natural-cubic-spline standard curve, inverted to estimate x from y.

    Standards are the wells whose slot label starts with 'S'. Their x and y
    readings are averaged per slot across all plates, the origin (0, 0) is
    prepended, and a natural cubic spline y = S(x) is fitted through those
    points. ``predict_x`` then solves S(x) = y for x inside the fitted range.
    """

    def __init__(self, plate_data_x, plate_data_y, plate_format):
        """Fit the spline from the standard wells.

        Args:
            plate_data_x: mapping plate key -> list of per-well x values.
            plate_data_y: mapping plate key -> list of per-well y values.
            plate_format: mapping plate key -> list of per-well slot labels;
                labels starting with 'S' are treated as standards.

        Raises:
            ValueError: if no standard slot exists, or if the averaged standard
                x values (after prepending the origin) are not strictly
                increasing, e.g. two standards share an x value or one has x <= 0.
        """
        # Calculate unique average values for x and y based on plate format
        self.X_unique, self.Y_unique = self.calculate_unique_values(plate_data_x, plate_data_y, plate_format)

        # Anchor the curve at the origin (0, 0)
        self.X_unique = np.insert(self.X_unique, 0, 0)
        self.Y_unique = np.insert(self.Y_unique, 0, 0)

        if self.X_unique.size < 2:
            raise ValueError("At least one standard ('S') slot is required.")
        # Spline knots must be strictly increasing: duplicate standards, or a
        # standard at x <= 0 (colliding with the prepended origin), are invalid.
        if not np.all(np.diff(self.X_unique) > 0):
            raise ValueError("X values must be strictly increasing.")

        # Natural cubic spline with zero second derivatives at the boundaries
        self.cs = CubicSpline(self.X_unique, self.Y_unique, bc_type='natural')

        # Row i is [d, c, b, a] for segment i, in powers of (x - X_unique[i])
        self.coefficients = self.cs.c.T
        self.plate_format = plate_format
        self.plate_data_x = plate_data_x
        self.plate_data_y = plate_data_y

    def calculate_unique_values(self, plate_data_x, plate_data_y, plate_format):
        """Average the x and y readings of every standard slot across plates.

        Returns:
            (X_unique, Y_unique): numpy arrays with one entry per 'S' slot,
            both ordered by ascending average x so that entry k of each array
            belongs to the same slot.
        """
        # Gather every reading of each 'S' slot, plate by plate, well by well.
        x_by_slot = {}
        y_by_slot = {}
        for plate, slots in plate_format.items():
            for i, slot in enumerate(slots):
                if slot.startswith('S'):
                    x_by_slot.setdefault(slot, []).append(plate_data_x[plate][i])
                    y_by_slot.setdefault(slot, []).append(plate_data_y[plate][i])

        average_x_values = {slot: np.mean(values) for slot, values in x_by_slot.items()}
        average_y_values = {slot: np.mean(values) for slot, values in y_by_slot.items()}

        # One shared ordering keeps each (x, y) pair aligned, even if averages tie.
        ordered_slots = sorted(average_x_values, key=average_x_values.get)
        X_unique = np.array([average_x_values[slot] for slot in ordered_slots])
        Y_unique = np.array([average_y_values[slot] for slot in ordered_slots])
        return X_unique, Y_unique

    def extract_coefficients(self):
        """Print the knot interval and a, b, c, d of every spline segment.

        Segment i is S_i(x) = a + b*t + c*t**2 + d*t**3 with t = x - x_i.
        """
        def extract_segment_coeffs(segment):
            # CubicSpline.c rows run from the cubic term (row 0) to the constant (row 3).
            a = self.cs.c[3, segment]
            b = self.cs.c[2, segment]
            c = self.cs.c[1, segment]
            d = self.cs.c[0, segment]
            x_start = self.X_unique[segment]
            x_end = self.X_unique[segment + 1]
            return x_start, x_end, a, b, c, d

        print("Extracted coefficients:")
        for i in range(len(self.X_unique) - 1):
            x_start, x_end, a, b, c, d = extract_segment_coeffs(i)
            print(f"x_{i+1} = {x_start}")
            print(f"x_{i+2} = {x_end}")
            print(f"a_{i+1} = {a}")
            print(f"b_{i+1} = {b}")
            print(f"c_{i+1} = {c}")
            print(f"d_{i+1} = {d}")
            print()

    def predict_x(self, y_val):
        """Return the x within the fitted range at which the spline equals ``y_val``.

        Segments are searched from the lowest x upward and the smallest valid
        root is returned, so a y value reached more than once yields the
        smallest such x. Returns None when ``y_val`` is not attained anywhere in
        [X_unique[0], X_unique[-1]]; there is no extrapolation.
        """
        for i in range(len(self.X_unique) - 1):
            x_start = self.X_unique[i]
            width = self.X_unique[i + 1] - x_start
            d, c, b, a = self.coefficients[i]
            # The coefficients are in the local variable t = x - x_start, so the
            # roots are offsets into this segment rather than absolute x values.
            roots = np.roots([d, c, b, a - y_val])
            # np.roots works through eigenvalues: a root sitting on a knot can come
            # out a hair outside the segment or with a tiny imaginary part.
            tol = 1e-9 * width
            offsets = [
                min(max(root.real, 0.0), width)
                for root in roots
                if abs(root.imag) <= 1e-9 * max(1.0, abs(root.real))
                and -tol <= root.real <= width + tol
            ]
            if offsets:
                return x_start + min(offsets)
        return None

    def predict_and_print_x_values(self):
        """Print a table of predicted x, plate %CV and % difference for every well.

        The % difference is taken against the well's own x value from
        ``plate_data_x``; it is left blank when that value is 0, since a
        percentage of zero is undefined. Wells whose y value cannot be inverted
        get blank prediction columns.
        """
        print("\nPredicted x values for each y in plate_data_y:")
        additional_table = []

        for plate, wells in self.plate_format.items():
            y_values = self.plate_data_y[plate]
            x_values = self.plate_data_x[plate]
            # The CV depends only on the plate, so compute it once per plate.
            cv = self.calculate_cv(plate)
            for i, well in enumerate(wells):
                y_val = y_values[i]
                x_pred = self.predict_x(y_val)
                if x_pred is None:
                    additional_table.append([well, plate + str(i + 1), y_val, '', '', ''])
                    continue
                x_nominal = x_values[i]
                # Compare concentrations with concentrations (not with the y reading).
                concentration_difference = ((x_pred - x_nominal) / x_nominal) * 100 if x_nominal else ''
                additional_table.append([well, plate + str(i + 1), y_val, x_pred, cv, concentration_difference])

        headers = ["Wells", "Plate", "FI - EndPoint", "CONCENTRATION", "%CV", "CONCENTRATION DIFFERENCE (%)"]
        print(tabulate(additional_table, headers=headers, tablefmt="grid"))

    def calculate_cv(self, plate):
        """Return the %CV of the predicted x values over all wells of ``plate``.

        Uses the sample standard deviation (ddof=1). Wells whose y value cannot
        be inverted are skipped. Returns None when fewer than two predictions
        exist, and 0 when their mean is 0.
        """
        predictions = (self.predict_x(y_val) for y_val in self.plate_data_y[plate])
        concentrations = [x for x in predictions if x is not None]
        if len(concentrations) < 2:
            return None
        mean, std_dev = np.mean(concentrations), np.std(concentrations, ddof=1)
        return (std_dev / mean) * 100 if mean != 0 else 0

# Input data: per-plate x values (the concentrations the curve predicts),
# per-plate y values (the measured readings), and per-well slot labels.
# Slots starting with 'S' are the standards used to fit the curve.
plate_data_x = {
    "A": [100, 50, 25, 12.5, 6.25, 0],
    "B": [100, 50, 25, 12.5, 6.25, 0],
    "C": [100, 50, 25, 12.5, 6.25, 0],
    "D": [0, 0, 0, 0, 0, 0]
}

plate_data_y = {
    "A": [640, 426, 256, 130, 70, 8],
    "B": [664, 400, 240, 134, 76, 10],
    "C": [635, 410, 244, 136, 73, 8],
    "D": [400, 421, 301, 290, 9, 9]
}

plate_format = {
    'A': ['S5', 'S4', 'S3', 'S2', 'S1', 'B'],
    'B': ['S5', 'S4', 'S3', 'S2', 'S1', 'B'],
    'C': ['S5', 'S4', 'S3', 'S2', 'S1', 'B'],
    'D': ['U1', 'U1', 'U2', 'U2', 'X', 'B']
}


# Run the analysis only when executed as a script, so importing this module
# (e.g. to reuse the class) has no printing side effects.
if __name__ == "__main__":
    # Create cubic spline interpolator instance
    interpolator = CubicSplineInterpolator(plate_data_x, plate_data_y, plate_format)

    # Extract and print coefficients
    interpolator.extract_coefficients()

    # Predict and print x values for each y in plate_data_y - the additional table
    interpolator.predict_and_print_x_values()
