%% test_neural_starG.m - Standalone test for Neural Star_G
% End-to-end check of the Neural ★_G pipeline on synthetic data:
%   1. generate molecules observed under n_rot cyclic z-rotations,
%   2. extract features with extractStarGFeatures (defined elsewhere),
%   3. fit a ridge baseline and an MLP on those features,
%   4. re-evaluate both models on every cyclic shift of the test inputs.
% Requires StarGAlgebra and extractStarGFeatures on the MATLAB path.
clear; clc;
rng(42);  % fixed seed: data generation, split and MLP init are reproducible

fprintf('═══════════════════════════════════════════════════════════════════\n');
fprintf('  TESTING NEURAL ★_G\n');
fprintf('═══════════════════════════════════════════════════════════════════\n\n');

%% Generate data
n_samples = 500;
n_feat = 20;
n_rot = 12;

G = StarGAlgebra('cyclic', n_rot);

fprintf('Generating synthetic data...\n');
[X_data, Y_data] = generateMolecularData(n_samples, n_feat, n_rot);

%% Split (70% train / 15% val / remainder test)
n_train = round(0.7 * n_samples);
n_val = round(0.15 * n_samples);

perm = randperm(n_samples);
train_idx = perm(1:n_train);
val_idx = perm(n_train+1:n_train+n_val);
test_idx = perm(n_train+n_val+1:end);

%% Normalize
% Statistics are taken from the training split only, so validation/test
% rows do not influence preprocessing (no leakage into the reported R²).
X_train_raw = X_data(train_idx, :, :);
x_mu = mean(X_train_raw(:));
x_sd = std(X_train_raw(:));
y_mu = mean(Y_data(train_idx));
y_sd = std(Y_data(train_idx));
X_data = (X_data - x_mu) / x_sd;
Y_data = (Y_data - y_mu) / y_sd;

fprintf('Data shape: [%d, %d, %d]\n', size(X_data));
fprintf('Target range: [%.2f, %.2f]\n\n', min(Y_data), max(Y_data));

X_train = X_data(train_idx, :, :);
X_val = X_data(val_idx, :, :);
X_test = X_data(test_idx, :, :);
Y_train = Y_data(train_idx);
Y_val = Y_data(val_idx);
Y_test = Y_data(test_idx);

%% Extract invariant features
% norm_params is fitted on the training set and reused for val/test.
fprintf('Extracting invariant features...\n');
[train_feat, norm_params] = extractStarGFeatures(X_train, G, n_feat);
val_feat = extractStarGFeatures(X_val, G, n_feat, norm_params);
test_feat = extractStarGFeatures(X_test, G, n_feat, norm_params);

fprintf('Feature dimensions: %d\n', size(train_feat, 2));

%% Check feature statistics
fprintf('\nFeature statistics:\n');
fprintf('  Train: mean=%.4f, std=%.4f, min=%.4f, max=%.4f\n', ...
        mean(train_feat(:)), std(train_feat(:)), min(train_feat(:)), max(train_feat(:)));
fprintf('  Val:   mean=%.4f, std=%.4f, min=%.4f, max=%.4f\n', ...
        mean(val_feat(:)), std(val_feat(:)), min(val_feat(:)), max(val_feat(:)));

%% Method 1: Ridge regression (baseline)
fprintf('\n── Ridge Regression ──\n');

% Closed-form ridge solution without an intercept column. Y_train has
% zero mean after the train-split normalization above; if
% extractStarGFeatures also centers its outputs (presumably via
% norm_params — verify), the unpenalized intercept would be zero anyway.
n_f = size(train_feat, 2);
lambda = 1.0;
W_ridge = (train_feat' * train_feat + lambda * eye(n_f)) \ (train_feat' * Y_train);

Y_pred_ridge = test_feat * W_ridge;
r2_ridge = computeR2(Y_pred_ridge, Y_test);
fprintf('Ridge R² = %.4f\n', r2_ridge);

%% Method 2: MLP
fprintf('\n── MLP Training ──\n');

hidden_layers = [64, 32];
fprintf('Architecture: [%d, %s, 1]\n', n_f, mat2str(hidden_layers));

[W_nn, b_nn, history] = trainMLP_v2(train_feat, Y_train, val_feat, Y_val, hidden_layers, ...
                                     'epochs', 300, ...
                                     'learningRate', 0.01, ...
                                     'patience', 30, ...
                                     'verbose', true);

Y_pred_nn = forwardMLP_v2(test_feat, W_nn, b_nn);
r2_nn = computeR2(Y_pred_nn, Y_test);
fprintf('\nMLP R² = %.4f\n', r2_nn);

%% Test invariance
% Cyclically shift the rotation axis (dim 3) of the test inputs by each
% group element and re-score both models; invariant features should
% give (near-)identical R² for every shift.
fprintf('\n── Rotation Invariance Test ──\n');

rot_r2_ridge = zeros(n_rot, 1);
rot_r2_nn = zeros(n_rot, 1);

for g = 1:n_rot
    X_test_rot = circshift(X_test, g-1, 3);
    feat_rot = extractStarGFeatures(X_test_rot, G, n_feat, norm_params);
    
    Y_pred_ridge_rot = feat_rot * W_ridge;
    rot_r2_ridge(g) = computeR2(Y_pred_ridge_rot, Y_test);
    
    Y_pred_nn_rot = forwardMLP_v2(feat_rot, W_nn, b_nn);
    rot_r2_nn(g) = computeR2(Y_pred_nn_rot, Y_test);
end

fprintf('Ridge: R²=%.4f, σ=%.2e\n', mean(rot_r2_ridge), std(rot_r2_ridge));
fprintf('MLP:   R²=%.4f, σ=%.2e\n', mean(rot_r2_nn), std(rot_r2_nn));

%% Summary
fprintf('\n═══════════════════════════════════════════════════════════════════\n');
fprintf('  SUMMARY\n');
fprintf('═══════════════════════════════════════════════════════════════════\n');
fprintf('Ridge Regression: R² = %.4f, rotation σ = %.2e\n', r2_ridge, std(rot_r2_ridge));
fprintf('Neural ★_G (MLP): R² = %.4f, rotation σ = %.2e\n', r2_nn, std(rot_r2_nn));
fprintf('═══════════════════════════════════════════════════════════════════\n');

%% Plot
figure('Position', [100, 100, 1000, 400]);

subplot(1, 3, 1);
plot(history.train_loss, 'b-', 'LineWidth', 1.5);
hold on;
plot(history.val_loss, 'r-', 'LineWidth', 1.5);
xlabel('Epoch');
ylabel('Loss');
legend('Train', 'Val');
title('Training History');
grid on;

subplot(1, 3, 2);
scatter(Y_test, Y_pred_ridge, 50, 'b', 'filled', 'MarkerFaceAlpha', 0.5);
hold on;
scatter(Y_test, Y_pred_nn, 50, 'r', 'filled', 'MarkerFaceAlpha', 0.5);
% Identity line spanning the plotted data instead of a fixed [-3, 3].
all_vals = [Y_test(:); Y_pred_ridge(:); Y_pred_nn(:)];
lims = [min(all_vals), max(all_vals)];
plot(lims, lims, 'k--', 'LineWidth', 1.5);
xlabel('True');
ylabel('Predicted');
legend('Ridge', 'MLP');
title('Predictions vs True');
grid on;

subplot(1, 3, 3);
angles = (0:n_rot-1) * 360 / n_rot;
plot(angles, rot_r2_ridge, 'b-o', 'LineWidth', 2, 'MarkerFaceColor', 'b');
hold on;
plot(angles, rot_r2_nn, 'r-s', 'LineWidth', 2, 'MarkerFaceColor', 'r');
xlabel('Rotation Angle (°)');
ylabel('R²');
legend('Ridge', 'MLP');
title('R² vs Rotation');
grid on;

sgtitle('Neural ★_G Performance');


%% ========================================================================
%% HELPER FUNCTIONS
%% ========================================================================

function r2 = computeR2(y_pred, y_true)
%COMPUTER2 Coefficient of determination between predictions and targets.
%   r2 = computeR2(y_pred, y_true) flattens both inputs to columns and
%   returns 1 - SS_res / (SS_tot + 1e-10). The small constant keeps the
%   ratio finite when y_true is constant.
    residual = y_pred(:) - y_true(:);
    deviation = y_true(:) - mean(y_true(:));
    r2 = 1 - sum(residual.^2) / (sum(deviation.^2) + 1e-10);
end


function [X, Y] = generateMolecularData(n_samples, n_feat, n_rot)
%GENERATEMOLECULARDATA Synthetic point-cloud dataset observed under rotations.
%   [X, Y] = generateMolecularData(n_samples, n_feat, n_rot) draws
%   n_samples random "molecules". Each has 4-10 atoms with centered 3-D
%   positions and integer charges Z (values 2-5 are mapped to 6).
%   Outputs:
%     X : n_samples x n_feat x n_rot. X(i,:,g) holds hand-crafted features
%         of molecule i after rotating it by 2*pi*(g-1)/n_rot about z.
%     Y : n_samples x 1 target built from pairwise distances and charges,
%         so it does not depend on g.
%   Features 1-5 are charge-weighted sums of x, y, z, r and r^2.
%   Features 6..n_feat are charge-weighted angular harmonics
%   cos(k*phi)*exp(-r^2/8) for k = 1, 2, ...
%   If n_feat < 5, only the first n_feat base features are kept.
%   Uses base MATLAB only (no Statistics Toolbox pdist).
    X = zeros(n_samples, n_feat, n_rot);
    Y = zeros(n_samples, 1);
    harmonics = 1:(n_feat - 5);   % 1 x K harmonic orders; empty when n_feat <= 5

    for i = 1:n_samples
        n_atoms = randi([4, 10]);
        pos = randn(n_atoms, 3) * 2;
        pos = pos - mean(pos, 1);
        Z = randi([1, 9], n_atoms, 1);
        Z(Z > 1 & Z < 6) = 6;

        % Unique pairwise Euclidean distances (strict upper triangle).
        % This is the same set of values pdist(pos) would return.
        delta = permute(pos, [1 3 2]) - permute(pos, [3 1 2]);   % n x n x 3
        D = sqrt(sum(delta.^2, 3));
        dists = D(triu(true(n_atoms), 1));
        if isempty(dists), dists = 1; end
        Y(i) = mean(dists) + 0.3 * std(dists) + sum(Z.^2) / 500;

        for g = 1:n_rot
            theta = 2 * pi * (g - 1) / n_rot;
            R = [cos(theta), -sin(theta), 0; sin(theta), cos(theta), 0; 0, 0, 1];
            pos_rot = (R * pos')';

            r = sqrt(sum(pos_rot.^2, 2));                 % n_atoms x 1
            phi = atan2(pos_rot(:, 2), pos_rot(:, 1));    % azimuth per atom

            % Charge-weighted sums: x, y, z, r, r^2 -> 5 x 1.
            base = [pos_rot' * Z; r' * Z; (r.^2)' * Z];
            % Charge-weighted harmonics -> K x 1.
            harm = (cos(phi * harmonics) .* exp(-r.^2 / 8))' * Z;

            feat = [base; harm];
            X(i, :, g) = feat(1:n_feat)';
        end
    end
end


function [W, b, history] = trainMLP_v2(X_train, Y_train, X_val, Y_val, hidden_layers, varargin)
%TRAINMLP_V2 Train a ReLU MLP regressor with mini-batch Adam and early stopping.
%   [W, b, history] = trainMLP_v2(X_train, Y_train, X_val, Y_val, hidden_layers, ...)
%     X_train, X_val : sample-per-row feature matrices (n x d)
%     Y_train, Y_val : targets; reshaped to column vectors internally
%     hidden_layers  : vector of hidden-layer widths, e.g. [64, 32]
%   Name-value options (defaults): 'epochs' (200), 'learningRate' (0.01),
%     'batchSize' (32), 'patience' (20), 'verbose' (false).
%   Returns:
%     W, b    : cells of layer weights (fan_out x fan_in) and biases
%               (fan_out x 1) from the epoch with the lowest validation MSE.
%     history : struct with train_loss / val_loss per completed epoch.
%   Evaluation uses forwardMLP_v2, defined elsewhere in this file.
    p = inputParser;
    addParameter(p, 'epochs', 200);
    addParameter(p, 'learningRate', 0.01);
    addParameter(p, 'batchSize', 32);
    addParameter(p, 'patience', 20);
    addParameter(p, 'verbose', false);
    parse(p, varargin{:});

    epochs = p.Results.epochs;
    lr = p.Results.learningRate;
    batch_size = p.Results.batchSize;
    patience = p.Results.patience;
    verbose = p.Results.verbose;

    Y_train = Y_train(:);
    Y_val = Y_val(:);

    % Layer widths: input -> hidden_layers -> 1 (scalar regression output).
    n_input = size(X_train, 2);
    n_output = 1;
    layer_sizes = [n_input, hidden_layers, n_output];
    n_layers = length(layer_sizes) - 1;

    % He initialization (std = sqrt(2/fan_in)), suited to ReLU hidden units.
    W = cell(n_layers, 1);
    b = cell(n_layers, 1);
    for l = 1:n_layers
        fan_in = layer_sizes(l);
        fan_out = layer_sizes(l+1);
        W{l} = randn(fan_out, fan_in) * sqrt(2 / fan_in);
        b{l} = zeros(fan_out, 1);
    end

    history = struct();
    history.train_loss = zeros(epochs, 1);
    history.val_loss = zeros(epochs, 1);

    best_val_loss = inf;
    best_W = W;
    best_b = b;
    wait = 0;
    n_done = 0;   % epochs actually completed; used to trim history (safe for epochs = 0)

    n_train = size(X_train, 1);
    n_batches = ceil(n_train / batch_size);

    % Adam first/second moment estimates, one per parameter tensor.
    m_W = cell(n_layers, 1);
    v_W = cell(n_layers, 1);
    m_b = cell(n_layers, 1);
    v_b = cell(n_layers, 1);
    for l = 1:n_layers
        m_W{l} = zeros(size(W{l}));
        v_W{l} = zeros(size(W{l}));
        m_b{l} = zeros(size(b{l}));
        v_b{l} = zeros(size(b{l}));
    end

    beta1 = 0.9;
    beta2 = 0.999;
    adam_eps = 1e-8;   % named so the builtin eps() is not shadowed
    t = 0;             % global Adam step counter for bias correction

    for epoch = 1:epochs
        perm = randperm(n_train);
        epoch_loss = 0;

        for batch = 1:n_batches
            t = t + 1;

            batch_start = (batch - 1) * batch_size + 1;
            batch_end = min(batch * batch_size, n_train);
            batch_idx = perm(batch_start:batch_end);

            X_batch = X_train(batch_idx, :);
            Y_batch = Y_train(batch_idx);
            n_batch = length(batch_idx);

            % Forward pass. Samples are columns; cache pre-activations Z
            % and activations A for backprop.
            A = cell(n_layers + 1, 1);
            Z = cell(n_layers, 1);
            A{1} = X_batch';
            for l = 1:n_layers
                Z{l} = W{l} * A{l} + b{l};
                if l < n_layers
                    A{l+1} = max(0, Z{l});
                else
                    A{l+1} = Z{l};
                end
            end

            Y_pred_batch = A{n_layers + 1}';
            batch_loss = mean((Y_pred_batch - Y_batch).^2);
            epoch_loss = epoch_loss + batch_loss;

            % dL/dY_pred for L = mean((Y_pred - Y).^2). The 1/n_batch of the
            % mean is applied here exactly once, so the parameter gradients
            % below are plain sums over the batch.
            dA = 2 * (A{n_layers + 1} - Y_batch') / n_batch;

            for l = n_layers:-1:1
                if l < n_layers
                    dZ = dA .* (Z{l} > 0);   % ReLU derivative
                else
                    dZ = dA;                 % linear output layer
                end

                dW = dZ * A{l}';
                db = sum(dZ, 2);

                % Propagate to the previous layer with the pre-update weights.
                if l > 1
                    dA = W{l}' * dZ;
                end

                m_W{l} = beta1 * m_W{l} + (1 - beta1) * dW;
                v_W{l} = beta2 * v_W{l} + (1 - beta2) * (dW.^2);
                m_b{l} = beta1 * m_b{l} + (1 - beta1) * db;
                v_b{l} = beta2 * v_b{l} + (1 - beta2) * (db.^2);

                m_W_hat = m_W{l} / (1 - beta1^t);
                v_W_hat = v_W{l} / (1 - beta2^t);
                m_b_hat = m_b{l} / (1 - beta1^t);
                v_b_hat = v_b{l} / (1 - beta2^t);

                W{l} = W{l} - lr * m_W_hat ./ (sqrt(v_W_hat) + adam_eps);
                b{l} = b{l} - lr * m_b_hat ./ (sqrt(v_b_hat) + adam_eps);
            end
        end

        history.train_loss(epoch) = epoch_loss / n_batches;

        Y_val_pred = forwardMLP_v2(X_val, W, b);
        history.val_loss(epoch) = mean((Y_val_pred - Y_val).^2);
        n_done = epoch;

        % Keep a snapshot of the best-validation parameters.
        if history.val_loss(epoch) < best_val_loss
            best_val_loss = history.val_loss(epoch);
            best_W = W;
            best_b = b;
            wait = 0;
        else
            wait = wait + 1;
        end

        if verbose && mod(epoch, 20) == 0
            Y_train_pred = forwardMLP_v2(X_train, W, b);
            train_r2 = computeR2_local(Y_train_pred, Y_train);
            val_r2 = computeR2_local(Y_val_pred, Y_val);
            fprintf('    Epoch %3d: loss=%.4f, val_loss=%.4f, R2=%.3f, val_R2=%.3f\n', ...
                    epoch, history.train_loss(epoch), history.val_loss(epoch), train_r2, val_r2);
        end

        if wait >= patience
            if verbose
                fprintf('    Early stopping at epoch %d\n', epoch);
            end
            break;
        end
    end

    % Return the best checkpoint rather than the final weights.
    W = best_W;
    b = best_b;
    history.train_loss = history.train_loss(1:n_done);
    history.val_loss = history.val_loss(1:n_done);

    function r2 = computeR2_local(y_pred, y_true)
        % R² with a small denominator guard; used only for verbose logging.
        y_pred = y_pred(:);
        y_true = y_true(:);
        ss_res = sum((y_pred - y_true).^2);
        ss_tot = sum((y_true - mean(y_true)).^2) + 1e-10;
        r2 = 1 - ss_res / ss_tot;
    end
end


function Y = forwardMLP_v2(X, W, b)
%FORWARDMLP_V2 Forward pass of a fully connected ReLU network.
%   Y = forwardMLP_v2(X, W, b) treats rows of X as samples. Every layer
%   except the last applies max(0, .). The last layer is linear. The
%   result is transposed back to sample-per-row. With an empty W the
%   input is returned unchanged (transposed twice).
    H = X';
    n_hidden = length(W) - 1;

    for k = 1:n_hidden
        H = max(0, W{k} * H + b{k});
    end

    if n_hidden >= 0
        H = W{n_hidden + 1} * H + b{n_hidden + 1};
    end

    Y = H';
end