import pytest
import numpy as np
from calculation_methods.py_calculations.linear_regression import LinearRegression
from calculation.py_tests.unit_test_utils import Utils 

@pytest.fixture
def combined_sample_data():
    """Provide the ``(input_data, standard_values, blank_values)`` triple.

    The three components are looked up by key in the dictionary produced by
    ``Utils.generate_sample_data()``, so every test gets freshly generated
    data. A missing key raises ``KeyError``.
    """
    generated = Utils.generate_sample_data()
    wanted_keys = ("input_data", "standard_values", "blank_values")
    return tuple(generated[key] for key in wanted_keys)

# Test case for standard samples
def test_curve_y_for_standard_samples(combined_sample_data):
    """Each standard-sample row carries a one-element ``curve_y`` list of floats.

    Runs an unweighted regression that is not forced through zero. A row
    counts as a standard sample when its identifier starts with ``S``
    (case-insensitive).

    Checks:
    - the number of such rows equals the number of distinct identifiers in
      ``standard_values``;
    - each row's ``curve_y`` is a list of exactly one float.
    """
    input_data, standard_values, blank_values = combined_sample_data
    weighted, forced_to_zero = False, False

    results = LinearRegression(
        input_data, standard_values, blank_values, weighted, forced_to_zero
    ).execute()
    rows = results.get("Additional_Table_Details", [])

    curves = [
        row['curve_y']
        for row in rows
        if row['identifier'].upper().startswith('S')
    ]
    distinct_standards = {std['identifier'] for std in standard_values}

    assert len(curves) == len(distinct_standards)
    for curve in curves:
        assert isinstance(curve, list)
        assert all(isinstance(value, (float, np.float64)) for value in curve)
        assert len(curve) == 1

# Test case for non-standard samples
def test_curve_y_not_generated_for_non_standard_samples(combined_sample_data):
    """Non-standard rows (blanks, unknowns, ...) must not carry a ``curve_y``.

    Runs an unweighted regression forced through zero. Every row whose
    identifier does not start with ``S`` (case-insensitive) must either omit
    ``curve_y`` or have it equal to ``[]``.

    NOTE(review): this passes vacuously if the generated data has no
    non-standard rows.
    """
    input_data, standard_values, blank_values = combined_sample_data
    weighted, forced_to_zero = False, True

    results = LinearRegression(
        input_data, standard_values, blank_values, weighted, forced_to_zero
    ).execute()
    rows = results.get("Additional_Table_Details", [])

    def _is_standard(row):
        # Standard samples are identified as S1, S2, ... (any case).
        return row['identifier'].upper().startswith('S')

    others = [row for row in rows if not _is_standard(row)]

    for row in others:
        # A missing key is treated like an empty list.
        assert row.get('curve_y', []) == []

# Test case for mixed identifiers (standard & non-standard)
def test_mixed_identifiers_curve_y(combined_sample_data):
    """Only standard rows have a non-empty ``curve_y`` in a weighted, zero-forced fit.

    Rows are split on whether their identifier starts with ``S``
    (case-insensitive):
    - standard rows must carry a truthy ``curve_y``;
    - all other rows must omit it or have it equal to ``[]``.
    """
    input_data, standard_values, blank_values = combined_sample_data
    weighted, forced_to_zero = True, True

    results = LinearRegression(
        input_data, standard_values, blank_values, weighted, forced_to_zero
    ).execute()
    rows = results.get("Additional_Table_Details", [])

    def _is_standard(row):
        return row['identifier'].upper().startswith('S')

    standards = [row for row in rows if _is_standard(row)]
    others = [row for row in rows if not _is_standard(row)]

    # Standard samples: key present and value non-empty/truthy.
    for row in standards:
        assert row.get('curve_y')

    # Non-standard samples: key absent (treated as []) or explicitly empty.
    for row in others:
        assert row.get('curve_y', []) == []