import numpy as np
from scipy.interpolate import CubicSpline
from tabulate import tabulate

class CubicSplineInterpolator:
    """Natural cubic-spline calibration curve built from plate standards.

    Wells whose slot label in ``plate_format`` starts with ``'S'`` are treated
    as standards. Their x and y readings are averaged per slot, ordered by
    averaged x, and fitted with a natural cubic spline (scipy ``CubicSpline``).
    The spline can then be inverted with :meth:`predict_x` to back-calculate
    x from a measured y.

    Args:
        plate_data_x: Mapping plate key -> list of x values, one per well.
        plate_data_y: Mapping plate key -> list of y values, one per well.
        plate_format: Mapping plate key -> list of slot labels, one per well.

    Raises:
        ValueError: If fewer than two standard slots are found, or if two
            standard slots average to the same x (the spline needs strictly
            increasing x).
    """

    def __init__(self, plate_data_x, plate_data_y, plate_format):
        self.X_unique, self.Y_unique = self.calculate_unique_values(plate_data_x, plate_data_y, plate_format)
        if len(self.X_unique) < 2:
            raise ValueError("At least two standard ('S*') slots are required to fit a spline.")
        # calculate_unique_values already sorts by x, so this only fails when
        # two standard slots share the same averaged x.
        if not np.all(np.diff(self.X_unique) > 0):
            raise ValueError("X values must be strictly increasing; duplicate standard X values found.")
        self.cs = CubicSpline(self.X_unique, self.Y_unique, bc_type='natural')
        # One row per segment: [cubic, quadratic, linear, constant] coefficients
        # of the piece written in the local variable (x - X_unique[i]).
        self.coefficients = self.cs.c.T
        self.plate_format = plate_format
        self.plate_data_y = plate_data_y

    def calculate_unique_values(self, plate_data_x, plate_data_y, plate_format):
        """Average x and y per standard slot and return them ordered by x.

        Returns:
            Tuple ``(X_unique, Y_unique)`` of numpy arrays. Entries are paired
            by slot and ordered by ascending averaged x.
        """
        average_x_values = self._average_standard_values(plate_data_x, plate_format)
        average_y_values = self._average_standard_values(plate_data_y, plate_format)
        # Use a single slot ordering for both arrays so X and Y stay paired.
        ordered_slots = sorted(average_x_values, key=average_x_values.get)
        X_unique = np.array([average_x_values[slot] for slot in ordered_slots])
        Y_unique = np.array([average_y_values[slot] for slot in ordered_slots])
        return X_unique, Y_unique

    @staticmethod
    def _average_standard_values(plate_data, plate_format):
        """Return ``{slot: mean value}`` for every slot label starting with 'S'."""
        values_by_slot = {}
        for key, slots in plate_format.items():
            for i, slot in enumerate(slots):
                if slot.startswith('S'):
                    values_by_slot.setdefault(slot, []).append(plate_data[key][i])
        return {slot: np.mean(values) for slot, values in values_by_slot.items()}

    def extract_coefficients(self):
        """Print the knots and coefficients of every spline segment.

        Segment i on ``[x_i, x_{i+1}]`` is
        ``a + b*(x - x_i) + c*(x - x_i)**2 + d*(x - x_i)**3``.
        """
        print("Extracted coefficients:")
        for i in range(len(self.X_unique) - 1):
            x_start, x_end = self.X_unique[i], self.X_unique[i + 1]
            # scipy orders coefficients highest power first.
            d, c, b, a = self.cs.c[:, i]
            print(f"x_{i+1} = {x_start}")
            print(f"x_{i+2} = {x_end}")
            print(f"a_{i+1} = {a}")
            print(f"b_{i+1} = {b}")
            print(f"c_{i+1} = {c}")
            print(f"d_{i+1} = {d}")
            print()

    def predict_x(self, y_val):
        """Invert the spline and return an x at which it evaluates to ``y_val``.

        Segments are searched in ascending x order. The first real root that
        lies inside its segment is returned, so results are always within
        ``[X_unique[0], X_unique[-1]]``.

        Returns:
            The interpolated x, or None if the spline never reaches ``y_val``.
        """
        for i in range(len(self.X_unique) - 1):
            x_start = self.X_unique[i]
            width = self.X_unique[i + 1] - x_start
            # Each piece is a polynomial in the *local* variable t = x - x_start,
            # so its roots are offsets into the segment, not absolute x values.
            d, c, b, a = self.coefficients[i]
            roots = np.roots([d, c, b, a - y_val])
            # Small tolerances keep roots that land exactly on a knot, or that
            # carry a round-off imaginary part, from being discarded.
            edge_tol = 1e-9 * width
            for root in roots:
                t = root.real
                if abs(root.imag) > 1e-9 * max(1.0, abs(t)):
                    continue
                if -edge_tol <= t <= width + edge_tol:
                    return x_start + min(max(t, 0.0), width)
        return None

    def predict_and_print_x_values(self):
        """Print a grid table that back-calculates x for every well.

        Wells whose y value the spline cannot reach get empty prediction
        columns.
        """
        print("\nAdditional table")
        additional_table = []
        for plate, wells in self.plate_format.items():
            y_values = self.plate_data_y[plate]
            # The %CV depends only on the plate, so compute it once per plate
            # rather than once per well.
            cv = self.calculate_cv(plate)
            for i, well in enumerate(wells):
                y_val = y_values[i]
                x_pred = self.predict_x(y_val)
                if x_pred is None:
                    additional_table.append([well, plate + str(i + 1), y_val, '', '', ''])
                    continue
                # NOTE(review): this compares a predicted x (reported as
                # CONCENTRATION) against a y reading (FI); these are different
                # quantities, so this percentage is probably not the intended
                # metric. Confirm with the author.
                concentration_difference = ((x_pred - y_val) / y_val) * 100
                additional_table.append([well, plate + str(i + 1), y_val, x_pred, cv, concentration_difference])

        headers = ["Wells", "Plate", "FI - EndPoint", "CONCENTRATION", "%CV", "CONCENTRATION DIFFERENCE (%)"]
        print(tabulate(additional_table, headers=headers, tablefmt="grid"))

    def calculate_cv(self, plate):
        """Return the %CV (sample std / mean * 100) of predicted x across ``plate``.

        Returns:
            None if fewer than two wells have a prediction, and 0 if the mean
            prediction is 0.
        """
        # NOTE(review): this pools every well of the plate row, including
        # different standard slots. A per-replicate-group CV may have been
        # intended. Confirm.
        predictions = (self.predict_x(y_val) for y_val in self.plate_data_y[plate])
        concentrations = [x for x in predictions if x is not None]
        if len(concentrations) < 2:
            return None
        mean, std_dev = np.mean(concentrations), np.std(concentrations, ddof=1)
        return (std_dev / mean) * 100 if mean != 0 else 0

# Input data.
# Each dict is keyed by plate row, and each list holds one entry per well
# (six wells per row). Rows A-C repeat the same x series and slot layout.
plate_data_x = {row: [100, 50, 25, 12.5, 6.25, 0] for row in "ABC"}
plate_data_x["D"] = [0] * 6

# Measured y readings; the report prints them under "FI - EndPoint".
plate_data_y = {
    "A": [640, 426, 256, 130, 70, 8],
    "B": [664, 400, 240, 134, 76, 10],
    "C": [635, 410, 244, 136, 73, 8],
    "D": [400, 421, 301, 290, 9, 9],
}

# Slot label of every well. Only labels starting with 'S' (S5..S1) are used
# to fit the curve; 'B', 'U1', 'U2' and 'X' are presumably blank/unknown
# samples (confirm).
plate_format = {row: ['S5', 'S4', 'S3', 'S2', 'S1', 'B'] for row in "ABC"}
plate_format['D'] = ['U1', 'U1', 'U2', 'U2', 'X', 'B']

# Fit the calibration curve from the standard slots in the layout above.
interpolator = CubicSplineInterpolator(
    plate_data_x=plate_data_x,
    plate_data_y=plate_data_y,
    plate_format=plate_format,
)

# Report: first the per-segment spline coefficients, then the
# back-calculated table (the "additional table") for every well.
interpolator.extract_coefficients()
interpolator.predict_and_print_x_values()
