% Repository path: tensor-group-sym / experiments / run_main_comparison.m
%% run_neural_starG_vs_ENN.m
%% Neural StarG vs Equivariant Neural Network Comparison
%% For Nature Communications
%% ============================================================================
%%
%% This script compares:
%%   1. StarG-SVD + Ridge (linear, exact invariance) - OUR METHOD
%%   2. Neural StarG (non-linear, exact invariance) - OUR METHOD
%%   3. Standard MLP (non-linear, NO invariance) - BASELINE
%%   4. Invariant MLP (non-linear, approximate invariance) - ENN PROXY
%%   5. Augmented MLP (non-linear, learned invariance) - ENN PROXY
%%
%% LH & SU & Claude 2026
%% ============================================================================

% Start from an empty workspace, a clean command window and no open figures.
clear;
clc;
close all;

% Title banner (emitted as one write; text is unchanged).
fprintf(['\n' ...
         '  Neural _G vs Equivariant Neural Networks\n' ...
         '  Comprehensive Comparison for Nature Communications\n' ...
         '\n\n']);

% Output folder for the figures and the .mat file written at the end of the
% script; created on first use.
resultsDir = 'neural_vs_enn_results';
if exist(resultsDir, 'dir') ~= 7
    mkdir(resultsDir);
end

%% ========================================================================
%% Configuration
%% ========================================================================

n_samples = 1500;   % number of synthetic molecules passed to generateMolecularData
n_rot     = 12;     % number of rotation steps (order of the cyclic group)
n_feat    = 20;     % feature-vector length per molecule and rotation
n_runs    = 3;      % number of repetitions of the method comparison

% Echo the settings in a single formatted write.
fprintf(['Configuration:\n' ...
         '  Samples: %d\n' ...
         '  Rotations: %d\n' ...
         '  Features: %d\n' ...
         '  Runs: %d\n\n'], ...
        n_samples, n_rot, n_feat, n_runs);

%% ========================================================================
%% Generate Data
%% ========================================================================

fprintf('\n');
fprintf('  Generating Molecular Dataset\n');
fprintf('\n\n');

% Group object handed to extractStarGFeatures later in the script
% (StarGAlgebra is defined outside this file).
G = StarGAlgebra('cyclic', n_rot);

% Seed the generator BEFORE sampling. generateMolecularData draws from
% randi/randn, and `clear` does not reset the random stream, so without this
% call every execution in the same MATLAB session produced a different
% dataset and the reported numbers could not be reproduced.
rng(42);
[X_data, Y_data] = generateMolecularData(n_samples, n_feat, n_rot);

% Normalize: inputs with one global mean/std over every entry, targets to
% zero mean and unit variance. The 1e-10 guards against division by zero.
% NOTE(review): these statistics are computed over ALL samples, i.e. before
% the train/val/test split, so validation/test data influence the scaling.
X_data = (X_data - mean(X_data(:))) / (std(X_data(:)) + 1e-10);
Y_mean = mean(Y_data);
Y_std = std(Y_data) + 1e-10;
Y_data = (Y_data - Y_mean) / Y_std;

fprintf('Data shape: [%d, %d, %d]\n', size(X_data));
fprintf('Target range: [%.2f, %.2f]\n\n', min(Y_data), max(Y_data));

%% ========================================================================
%% Train/Val/Test Split
%% ========================================================================

% Fixed seed so the 70 / 15 / 15 partition is identical on every execution.
rng(42);
n_train = round(0.7 * n_samples);
n_val   = round(0.15 * n_samples);
n_test  = n_samples - (n_train + n_val);   % remainder goes to the test set

% One random permutation, cut into three consecutive index ranges.
perm      = randperm(n_samples);
train_idx = perm(1:n_train);
val_idx   = perm(n_train + (1:n_val));
test_idx  = perm((n_train + n_val + 1):n_samples);

% Inputs are sliced along the sample dimension (dim 1); targets likewise.
X_train = X_data(train_idx, :, :);
Y_train = Y_data(train_idx);

X_val = X_data(val_idx, :, :);
Y_val = Y_data(val_idx);

X_test = X_data(test_idx, :, :);
Y_test = Y_data(test_idx);

fprintf('Split: Train=%d, Val=%d, Test=%d\n\n', n_train, n_val, n_test);

%% ========================================================================
%% Verify Exact Feature Invariance
%% ========================================================================

fprintf('\n');
fprintf('  Verifying Exact Feature Invariance                             \n');
fprintf('\n\n');

% Fit the extractor's second output on a small batch of training samples.
% extractStarGFeatures is defined outside this file; its second output is
% presumably a set of normalization parameters (it is passed back in as the
% 4th argument below) -- TODO confirm. The feature matrix itself is not needed.
n_verify = min(10, size(X_train, 1));
test_samples = X_train(1:n_verify, :, :);

[~, norm_params_test] = extractStarGFeatures(test_samples, G, n_feat);

% Reference features of the first training sample (kept 3-D via 1:1).
test_sample = X_train(1:1, :, :);
feat_orig = extractStarGFeatures(test_sample, G, n_feat, norm_params_test);

fprintf('Test sample shape: [%d, %d, %d]\n', size(test_sample));
fprintf('Feature vector length: %d\n\n', size(feat_orig, 2));

fprintf('Rotation  Angle   Max |Δfeature|\n');
fprintf('--------  -----   --------------\n');

% Cyclically shift the sample along dim 3 (the rotation axis) by every group
% element and record the largest absolute change in any feature.
% Names `feat_diff` / `rot_angle` are used instead of `diff` / `angle` so the
% built-in functions of those names are not shadowed for the rest of the script.
max_diff_all = 0;
for g = 1:n_rot
    test_rot = circshift(test_sample, g-1, 3);
    feat_rot = extractStarGFeatures(test_rot, G, n_feat, norm_params_test);

    feat_diff = max(abs(feat_orig(:) - feat_rot(:)));
    max_diff_all = max(max_diff_all, feat_diff);

    rot_angle = (g-1) * 360 / n_rot;
    fprintf('   %2d      %3d°    %.2e\n', g, rot_angle, feat_diff);
end

fprintf('\nMaximum difference: %.2e\n', max_diff_all);

% Classify the worst-case deviation: below 1e-12 counts as exact, below 1e-8
% as effectively invariant, anything larger triggers the warning.
if max_diff_all < 1e-12
    fprintf(' Features are EXACTLY invariant (machine precision)\n\n');
elseif max_diff_all < 1e-8
    fprintf(' Features are effectively invariant\n\n');
else
    fprintf(' WARNING: Features are NOT invariant!\n\n');
end

%% ========================================================================
%% Initialize Results Storage
%% ========================================================================

% Keys under which each compared method stores its metrics.
methods = {'StarG_SVD', 'Neural_StarG', 'MLP_r1', 'Invariant_MLP', 'Augmented_MLP'};
n_methods = numel(methods);

% Blank per-method record: one entry per run (one column per rotation for
% rotation_r2); n_params is a scalar that each run overwrites.
empty_record = struct( ...
    'test_r2',      zeros(n_runs, 1), ...
    'rotation_r2',  zeros(n_runs, n_rot), ...
    'rotation_std', zeros(n_runs, 1), ...
    'n_params',     0, ...
    'train_time',   zeros(n_runs, 1));

results = struct();
for m = 1:n_methods
    results.(methods{m}) = empty_record;
end

%% ========================================================================
%% Run Multiple Experiments
%% ========================================================================

% Repeat the full five-method comparison n_runs times. The data split is the
% same in every run; only the random stream differs. The loop variable is
% `run_idx` so that the built-in function `run` is not shadowed.
for run_idx = 1:n_runs
    fprintf('\n');
    fprintf('  RUN %d / %d\n', run_idx, n_runs);
    fprintf('\n\n');

    % Per-run seed: trainMLP draws its initial weights (randn) and mini-batch
    % order (randperm) from this stream.
    rng(run_idx * 123);

    %% ====================================================================
    %% Method 1: StarG-SVD + Ridge (Exact Invariance)
    %% ====================================================================

    fprintf(',  Method 1: _G-SVD + Ridge (Exact Invariance) , \n');

    tic;

    % Fit the extractor on the training set and reuse its second output
    % (presumably normalization parameters -- TODO confirm) for val/test.
    [train_feat, norm_params] = extractStarGFeatures(X_train, G, n_feat);
    val_feat = extractStarGFeatures(X_val, G, n_feat, norm_params);
    test_feat = extractStarGFeatures(X_test, G, n_feat, norm_params);

    n_extracted = size(train_feat, 2);
    fprintf('  Exact invariant features: %d\n', n_extracted);

    % Choose the ridge penalty by validation R^2 over a fixed grid.
    lambdas = [0.001, 0.01, 0.1, 1, 10];
    best_lambda = 0.1;
    best_val_r2 = -inf;

    for lam = lambdas
        W_temp = (train_feat' * train_feat + lam * eye(n_extracted)) \ (train_feat' * Y_train);
        Y_val_pred = val_feat * W_temp;
        val_r2 = computeR2(Y_val_pred, Y_val);
        if val_r2 > best_val_r2
            best_val_r2 = val_r2;
            best_lambda = lam;
        end
    end

    fprintf('  Best lambda: %.4f\n', best_lambda);

    % Closed-form ridge solution (no intercept term) with the selected penalty.
    W_svd = (train_feat' * train_feat + best_lambda * eye(n_extracted)) \ (train_feat' * Y_train);
    results.StarG_SVD.train_time(run_idx) = toc;

    Y_pred = test_feat * W_svd;
    results.StarG_SVD.test_r2(run_idx) = computeR2(Y_pred, Y_test);
    results.StarG_SVD.n_params = n_extracted;

    % Re-evaluate on the test set cyclically shifted along the rotation axis.
    rot_r2 = zeros(n_rot, 1);
    for g = 1:n_rot
        X_test_rot = circshift(X_test, g-1, 3);
        feat_g = extractStarGFeatures(X_test_rot, G, n_feat, norm_params);
        Y_pred_g = feat_g * W_svd;
        rot_r2(g) = computeR2(Y_pred_g, Y_test);
    end
    results.StarG_SVD.rotation_r2(run_idx, :) = rot_r2';
    results.StarG_SVD.rotation_std(run_idx) = std(rot_r2);

    fprintf('  R² = %.4f, σ = %.2e, params = %d\n', ...
        results.StarG_SVD.test_r2(run_idx), results.StarG_SVD.rotation_std(run_idx), n_extracted);

    if results.StarG_SVD.rotation_std(run_idx) < 1e-10
        fprintf('   EXACT invariance achieved!\n\n');
    else
        fprintf('  Rotation variance: %.2e\n\n', results.StarG_SVD.rotation_std(run_idx));
    end

    %% ====================================================================
    %% Method 2: Neural StarG (Exact Invariance)
    %% ====================================================================

    fprintf(',  Method 2: _G (Exact Invariance) , \n');

    tic;

    % Start from the Method 1 features; they are only replaced when the PCA
    % branch below is taken.
    train_feat_nn = train_feat;
    val_feat_nn = val_feat;
    test_feat_nn = test_feat;

    % Cap the MLP input width at 50 by projecting onto the leading principal
    % directions. Note the projection multiplies the raw (uncentered)
    % features by the PCA coefficients, identically for train/val/test.
    n_pca = min(50, n_extracted);
    if n_extracted > 50
        coeff_pca = pca(train_feat);
        pca_coeff = coeff_pca(:, 1:n_pca);
        train_feat_nn = train_feat * pca_coeff;
        val_feat_nn = val_feat * pca_coeff;
        test_feat_nn = test_feat * pca_coeff;
        use_pca = true;
    else
        % n_pca already equals n_extracted here, so no reassignment is needed.
        pca_coeff = eye(n_extracted);
        use_pca = false;
    end

    fprintf('  Features for MLP: %d\n', n_pca);

    hidden_layers = [32, 16];
    [W_nn, b_nn] = trainMLP(train_feat_nn, Y_train, val_feat_nn, Y_val, ...
                            hidden_layers, 200, 0.005, false);
    results.Neural_StarG.train_time(run_idx) = toc;

    Y_pred = forwardMLP(test_feat_nn, W_nn, b_nn);
    results.Neural_StarG.test_r2(run_idx) = computeR2(Y_pred, Y_test);
    results.Neural_StarG.n_params = countMLPParams([n_pca, hidden_layers, 1]);

    rot_r2 = zeros(n_rot, 1);
    for g = 1:n_rot
        X_test_rot = circshift(X_test, g-1, 3);
        feat_g = extractStarGFeatures(X_test_rot, G, n_feat, norm_params);

        if use_pca
            feat_g = feat_g * pca_coeff;
        end

        Y_pred_g = forwardMLP(feat_g, W_nn, b_nn);
        rot_r2(g) = computeR2(Y_pred_g, Y_test);
    end
    results.Neural_StarG.rotation_r2(run_idx, :) = rot_r2';
    results.Neural_StarG.rotation_std(run_idx) = std(rot_r2);

    fprintf('  R² = %.4f, σ = %.2e, params = %d\n', ...
        results.Neural_StarG.test_r2(run_idx), results.Neural_StarG.rotation_std(run_idx), ...
        results.Neural_StarG.n_params);

    if results.Neural_StarG.rotation_std(run_idx) < 1e-10
        fprintf('   EXACT invariance achieved!\n\n');
    else
        fprintf('  Rotation variance: %.2e\n\n', results.Neural_StarG.rotation_std(run_idx));
    end

    %% ====================================================================
    %% Method 3: Standard MLP (No Invariance)
    %% ====================================================================

    fprintf(',  Method 3: Standard MLP (rotation 1 only) , \n');

    % Only the first rotation slice is used for training. Indexing a single
    % page of a 3-D array already yields a 2-D matrix, so no squeeze is needed.
    X_train_r1 = X_train(:, :, 1);
    X_val_r1 = X_val(:, :, 1);
    X_test_r1 = X_test(:, :, 1);

    tic;
    [W_mlp, b_mlp] = trainMLP(X_train_r1, Y_train, X_val_r1, Y_val, [64, 32], 100, 0.005, false);
    results.MLP_r1.train_time(run_idx) = toc;

    Y_pred = forwardMLP(X_test_r1, W_mlp, b_mlp);
    results.MLP_r1.test_r2(run_idx) = computeR2(Y_pred, Y_test);
    results.MLP_r1.n_params = countMLPParams([n_feat, 64, 32, 1]);

    % Score the network on every rotation slice of the test set.
    for g = 1:n_rot
        X_test_g = X_test(:, :, g);
        Y_pred_g = forwardMLP(X_test_g, W_mlp, b_mlp);
        results.MLP_r1.rotation_r2(run_idx, g) = computeR2(Y_pred_g, Y_test);
    end
    results.MLP_r1.rotation_std(run_idx) = std(results.MLP_r1.rotation_r2(run_idx, :));

    fprintf('  R² = %.4f, σ = %.4f, params = %d\n\n', ...
        results.MLP_r1.test_r2(run_idx), results.MLP_r1.rotation_std(run_idx), results.MLP_r1.n_params);

    %% ====================================================================
    %% Method 4: Invariant MLP (ENN Proxy)
    %% ====================================================================

    fprintf(',  Method 4: Invariant MLP (ENN-style) , \n');

    % Pool each feature over the rotation axis (mean / std / min / max).
    X_train_inv = computeInvariantFeatures(X_train);
    X_val_inv = computeInvariantFeatures(X_val);
    X_test_inv = computeInvariantFeatures(X_test);

    n_inv_feat = size(X_train_inv, 2);

    tic;
    [W_inv, b_inv] = trainMLP(X_train_inv, Y_train, X_val_inv, Y_val, [64, 32], 100, 0.005, false);
    results.Invariant_MLP.train_time(run_idx) = toc;

    Y_pred = forwardMLP(X_test_inv, W_inv, b_inv);
    results.Invariant_MLP.test_r2(run_idx) = computeR2(Y_pred, Y_test);
    results.Invariant_MLP.n_params = countMLPParams([n_inv_feat, 64, 32, 1]);

    for g = 1:n_rot
        X_test_rot = circshift(X_test, g-1, 3);
        X_test_inv_g = computeInvariantFeatures(X_test_rot);
        Y_pred_g = forwardMLP(X_test_inv_g, W_inv, b_inv);
        results.Invariant_MLP.rotation_r2(run_idx, g) = computeR2(Y_pred_g, Y_test);
    end
    results.Invariant_MLP.rotation_std(run_idx) = std(results.Invariant_MLP.rotation_r2(run_idx, :));

    fprintf('  R² = %.4f, σ = %.2e, params = %d\n\n', ...
        results.Invariant_MLP.test_r2(run_idx), results.Invariant_MLP.rotation_std(run_idx), ...
        results.Invariant_MLP.n_params);

    %% ====================================================================
    %% Method 5: Augmented MLP (Data Augmentation)
    %% ====================================================================

    fprintf(',  Method 5: MLP with Full Augmentation , \n');

    % Stack every rotation slice of every training sample as its own row.
    % Moving the rotation axis to the front makes it the fastest-varying
    % index, so row (i-1)*n_rot + g holds X_train(i, :, g) -- the same layout
    % the former element-wise double loop produced. Targets are repeated
    % n_rot times each to match.
    X_train_aug = reshape(permute(X_train, [3, 1, 2]), n_train * n_rot, n_feat);
    Y_train_aug = repelem(Y_train(:), n_rot);

    tic;
    [W_aug, b_aug] = trainMLP(X_train_aug, Y_train_aug, X_val_r1, Y_val, [64, 32], 80, 0.005, false);
    results.Augmented_MLP.train_time(run_idx) = toc;

    Y_pred = forwardMLP(X_test_r1, W_aug, b_aug);
    results.Augmented_MLP.test_r2(run_idx) = computeR2(Y_pred, Y_test);
    results.Augmented_MLP.n_params = countMLPParams([n_feat, 64, 32, 1]);

    for g = 1:n_rot
        X_test_g = X_test(:, :, g);
        Y_pred_g = forwardMLP(X_test_g, W_aug, b_aug);
        results.Augmented_MLP.rotation_r2(run_idx, g) = computeR2(Y_pred_g, Y_test);
    end
    results.Augmented_MLP.rotation_std(run_idx) = std(results.Augmented_MLP.rotation_r2(run_idx, :));

    fprintf('  R² = %.4f, σ = %.4f, params = %d\n\n', ...
        results.Augmented_MLP.test_r2(run_idx), results.Augmented_MLP.rotation_std(run_idx), ...
        results.Augmented_MLP.n_params);
end

%% ========================================================================
%% Aggregate Results
%% ========================================================================

fprintf('\n');
fprintf('\n');
fprintf('  AGGREGATED RESULTS (Mean ± Std over %d runs)\n', n_runs);
fprintf('\n\n');

fprintf('%-20s  %-15s  %-20s  %-10s  %-12s\n', ...
    'Method', 'Test R²', 'Rotation σ', 'Params', 'Time (s)');
fprintf('%s\n', repmat('-', 1, 85));

% Collapse the per-run metrics of each method into one summary record and
% print one table row per method.
summary = struct();
for m = 1:n_methods
    method = methods{m};
    per_run = results.(method);

    summary.(method) = struct( ...
        'mean_r2',      mean(per_run.test_r2), ...
        'std_r2',       std(per_run.test_r2), ...
        'mean_rot_std', mean(per_run.rotation_std), ...
        'n_params',     per_run.n_params, ...
        'mean_time',    mean(per_run.train_time));

    % The spread of the rotation sigma is printed but not stored in summary.
    std_rot_std = std(per_run.rotation_std);

    fprintf('%-20s  %.4f ± %.4f   %.2e ± %.2e   %-10d  %.2f\n', ...
        strrep(method, '_', ' '), ...
        summary.(method).mean_r2, summary.(method).std_r2, ...
        summary.(method).mean_rot_std, std_rot_std, ...
        summary.(method).n_params, summary.(method).mean_time);
end

%% ========================================================================
%% Statistical Significance Tests
%% ========================================================================

fprintf('\n\n');
fprintf('  STATISTICAL ANALYSIS\n');
fprintf('\n\n');

fprintf('Comparing Neural _G vs other methods:\n\n');

% Neural StarG (entry 2 of `methods`) is the reference for every comparison.
baseline_r2 = results.Neural_StarG.test_r2;
baseline_std = results.Neural_StarG.rotation_std;

% Compare against every other method: entries 1, 3, 4 and 5.
for m = [1, 3, 4, 5]
    method = methods{m};
    other_r2 = results.(method).test_r2;
    other_std = results.(method).rotation_std;

    % Two-sample t-test on the per-run test R^2 values; only the p-value
    % (second output) is used.
    [~, p_r2] = ttest2(baseline_r2, other_r2);

    % Ratio of mean rotation sigmas; the 1e-15 keeps the denominator
    % non-zero when the baseline sigma is exactly 0.
    var_ratio = mean(other_std) / (mean(baseline_std) + 1e-15);

    r2_gap = mean(baseline_r2) - mean(other_r2);

    fprintf('  vs %s:\n', strrep(method, '_', ' '));
    fprintf('    R² difference: %.4f (p = %.4f)\n', r2_gap, p_r2);
    fprintf('    Variance ratio: %.2e\n\n', var_ratio);
end

%% ========================================================================
%% Generate Publication Figures
%% ========================================================================

fprintf('\n');
fprintf('  Generating Publication Figures\n');
fprintf('\n\n');

% One wide figure with four side-by-side panels.
fig = figure('Position', [50, 50, 1800, 600], 'Color', 'w');

% ---- Panel A: mean test R^2 per method with +/- one std error bars ----
subplot(1,4,1);
method_labels = {'_G-SVD', 'Neural _G', 'MLP', 'Inv. MLP', 'Aug. MLP'};

% Gather the summary statistics as column vectors (one entry per method).
mean_r2_all = cellfun(@(key) summary.(key).mean_r2, methods(:));
std_r2_all  = cellfun(@(key) summary.(key).std_r2,  methods(:));

% One RGB row per bar, in the order of `methods`.
bar_h = bar(mean_r2_all, 'FaceColor', 'flat');
bar_h.CData = [0.2, 0.6, 0.9; ...
               0.1, 0.4, 0.8; ...
               0.9, 0.3, 0.3; ...
               0.9, 0.6, 0.2; ...
               0.5, 0.8, 0.3];

hold on;
errorbar(1:n_methods, mean_r2_all, std_r2_all, 'k.', 'LineWidth', 1.5);
set(gca, 'XTickLabel', method_labels, 'XTickLabelRotation', 30);
ylabel('Test R²', 'FontSize', 13, 'FontWeight', 'bold');
title('A: Predictive Performance', 'FontSize', 15, 'FontWeight', 'bold');
grid on; box on;
set(gca, 'FontSize', 11);

% ---- Panel B: mean rotation sigma on a log axis ----
subplot(1,4,2);
% Clamp at 1e-16 so exact zeros remain drawable on the log scale.
mean_rot_std_all = max(cellfun(@(key) summary.(key).mean_rot_std, methods(:)), 1e-16);

bar_h2 = bar(mean_rot_std_all, 'FaceColor', 'flat');
bar_h2.CData = bar_h.CData;   % same colors as panel A
set(gca, 'XTickLabel', method_labels, 'XTickLabelRotation', 30);
ylabel('Rotation Variance (σ)', 'FontSize', 13, 'FontWeight', 'bold');
title('B: Rotation Robustness', 'FontSize', 15, 'FontWeight', 'bold');
set(gca, 'YScale', 'log');
grid on; box on;
set(gca, 'FontSize', 11);
ylim([1e-16, 1]);

% ---- Panel C: test R^2 against parameter count (log x-axis) ----
subplot(1,4,3);
n_params_all = cellfun(@(key) summary.(key).n_params, methods(:));

scatter(n_params_all, mean_r2_all, 200, 'filled');
hold on;
% Label each point slightly to the right of its marker.
for m = 1:n_methods
    text(n_params_all(m) * 1.1, mean_r2_all(m), method_labels{m}, 'FontSize', 10);
end
xlabel('Number of Parameters', 'FontSize', 13, 'FontWeight', 'bold');
ylabel('Test R²', 'FontSize', 13, 'FontWeight', 'bold');
title('C: Parameter Efficiency', 'FontSize', 15, 'FontWeight', 'bold');
set(gca, 'XScale', 'log');
grid on; box on;
set(gca, 'FontSize', 11);

% ---- Panel D: R^2 per rotation, averaged over runs ----
subplot(1,4,4);
angles = (0:n_rot-1) * 360 / n_rot;

% Row m holds the run-averaged rotation profile of method m.
mean_rot_r2 = zeros(n_methods, n_rot);
for m = 1:n_methods
    mean_rot_r2(m, :) = mean(results.(methods{m}).rotation_r2, 1);
end

% Methods 1, 2, 3 and 5 are drawn; method 4 (Invariant MLP) is not plotted.
plot(angles, mean_rot_r2(1, :), 'b-o', 'LineWidth', 2, 'MarkerSize', 8, 'MarkerFaceColor', 'b');
hold on;
plot(angles, mean_rot_r2(2, :), 'b--s', 'LineWidth', 2, 'MarkerSize', 8);
plot(angles, mean_rot_r2(3, :), 'r-^', 'LineWidth', 2, 'MarkerSize', 8);
plot(angles, mean_rot_r2(5, :), 'g-d', 'LineWidth', 2, 'MarkerSize', 8);

xlabel('Rotation Angle (°)', 'FontSize', 13, 'FontWeight', 'bold');
ylabel('Test R²', 'FontSize', 13, 'FontWeight', 'bold');
title('D: Performance vs Rotation', 'FontSize', 15, 'FontWeight', 'bold');
legend('_G-SVD', 'Neural _G', 'MLP', 'Aug. MLP', 'Location', 'best', 'FontSize', 10);
grid on; box on;
set(gca, 'FontSize', 11);

sgtitle('Neural _G vs Equivariant Neural Networks', 'FontSize', 18, 'FontWeight', 'bold');

% Write a raster copy and an editable MATLAB figure to the results folder.
saveas(fig, fullfile(resultsDir, 'neural_starG_vs_enn.png'));
saveas(fig, fullfile(resultsDir, 'neural_starG_vs_enn.fig'));

fprintf('Figures saved to: %s/\n\n', resultsDir);

%% ========================================================================
%% Comparison Table for Paper
%% ========================================================================

fprintf('\n');
fprintf('  TABLE FOR PUBLICATION\n');
fprintf('\n\n');

% Markdown table: header row, then the mandatory dash separator row. The
% separator previously contained stray commas instead of dashes, which
% prevents Markdown renderers from recognising the block as a table.
fprintf('| Method | Test R² | Rot. Variance | Parameters | Invariant? |\n');
fprintf('|--------|---------|---------------|------------|------------|\n');

% Invariance label per method, in the same order as `methods`.
invariant_flags = {'Yes (exact)', 'Yes (exact)', 'No', 'Yes (approx)', 'No'};

% One row per method, filled from the aggregated summary.
for m = 1:n_methods
    method_name = strrep(methods{m}, '_', ' ');
    fprintf('| %s | %.3f ± %.3f | %.2e | %d | %s |\n', ...
        method_name, summary.(methods{m}).mean_r2, summary.(methods{m}).std_r2, ...
        summary.(methods{m}).mean_rot_std, summary.(methods{m}).n_params, invariant_flags{m});
end

%% ========================================================================
%% Key Findings Summary
%% ========================================================================

fprintf('\n\n');
fprintf('  KEY FINDINGS FOR NATURE COMMUNICATIONS\n');
fprintf('\n\n');

% Values quoted in the findings, pulled from the aggregated summary.
sigma_mlp       = summary.MLP_r1.mean_rot_std;
sigma_aug       = summary.Augmented_MLP.mean_rot_std;
params_neural   = summary.Neural_StarG.n_params;
params_mlp      = summary.MLP_r1.n_params;
% max(..., 1) protects the ratio against a zero parameter count.
param_ratio     = params_mlp / max(params_neural, 1);
% Best mean R^2 among the three non-StarG methods.
best_baseline   = max([summary.MLP_r1.mean_r2, ...
                       summary.Augmented_MLP.mean_r2, ...
                       summary.Invariant_MLP.mean_r2]);

fprintf('1. EXACT INVARIANCE:\n');
fprintf('    _G methods achieve EXACT rotation invariance (σ  0)\n');
fprintf('    Standard MLP: σ = %.4f\n', sigma_mlp);
fprintf('    Augmented MLP: σ = %.4f\n', sigma_aug);

fprintf('\n2. PARAMETER EFFICIENCY:\n');
fprintf('    Neural _G: %d parameters\n', params_neural);
fprintf('    Standard MLP: %d parameters (%.1fx more)\n', params_mlp, param_ratio);

fprintf('\n3. PREDICTIVE PERFORMANCE:\n');
fprintf('    _G-SVD: R² = %.4f\n', summary.StarG_SVD.mean_r2);
fprintf('    Neural _G: R² = %.4f\n', summary.Neural_StarG.mean_r2);
fprintf('    Best baseline: R² = %.4f\n', best_baseline);

fprintf('\n\n');
fprintf('  Analysis Complete!\n');
fprintf('\n\n');

% Persist the raw per-run results, the aggregated summary and the method keys.
save(fullfile(resultsDir, 'comparison_results.mat'), 'results', 'summary', 'methods');

%% ========================================================================
%% HELPER FUNCTIONS
%% ========================================================================

function [X, Y] = generateMolecularData(n_samples, n_feat, n_rot)
%GENERATEMOLECULARDATA Synthetic point-cloud "molecules" seen under n_rot rotations.
%   [X, Y] = generateMolecularData(n_samples, n_feat, n_rot)
%
%   For each sample a random cloud of 4-10 points is drawn (centered Gaussian
%   coordinates scaled by 2) together with an integer weight per point in
%   {1, 6, 7, 8, 9} (draws of 2..5 are mapped to 6).
%
%   X  n_samples x n_feat x n_rot. Page g holds the features of the cloud
%      rotated about the z-axis by 2*pi*(g-1)/n_rot:
%        1-3  weighted sums of x, y, z
%        4-5  weighted sums of r and r^2 (r = distance from the origin)
%        6..  weighted sums of cos(k * atan2(y, x)) * exp(-r^2 / 8), k = f - 5
%   Y  n_samples x 1 target computed from the unrotated cloud:
%        mean(pairwise distance) + 0.3 * std(pairwise distance)
%        + sum(weight.^2) / 500
%
%   Uses the global random stream (randi, randn) and pdist.
    X = zeros(n_samples, n_feat, n_rot);
    Y = zeros(n_samples, 1);

    for mol = 1:n_samples
        % Random geometry, shifted so the centroid sits at the origin.
        n_atoms = randi([4, 10]);
        coords = randn(n_atoms, 3) * 2;
        coords = coords - mean(coords, 1);

        % Per-point integer weights; values 2..5 are collapsed onto 6.
        weight = randi([1, 9], n_atoms, 1);
        weight(weight > 1 & weight < 6) = 6;

        % Target depends only on pairwise distances and the weights.
        pair_dist = pdist(coords);
        if isempty(pair_dist), pair_dist = 1; end
        Y(mol) = mean(pair_dist) + 0.3 * std(pair_dist) + sum(weight.^2) / 500;

        for g = 1:n_rot
            % Rotation about the z-axis by the g-th multiple of 2*pi/n_rot.
            theta = 2 * pi * (g - 1) / n_rot;
            Rz = [cos(theta), -sin(theta), 0; ...
                  sin(theta),  cos(theta), 0; ...
                  0,           0,          1];
            rotated = (Rz * coords')';

            % Accumulate every feature point by point.
            feat = zeros(n_feat, 1);
            for a = 1:n_atoms
                px = rotated(a, 1);
                py = rotated(a, 2);
                pz = rotated(a, 3);
                rad = norm(rotated(a, :));
                w = weight(a);

                feat(1) = feat(1) + w * px;
                feat(2) = feat(2) + w * py;
                feat(3) = feat(3) + w * pz;
                feat(4) = feat(4) + w * rad;
                feat(5) = feat(5) + w * rad^2;

                % Angular harmonics damped by a Gaussian in r; the azimuth
                % and the damping factor do not depend on the harmonic index.
                azimuth = atan2(py, px);
                damping = exp(-rad^2 / 8);
                for f = 6:n_feat
                    feat(f) = feat(f) + w * cos((f-5) * azimuth) * damping;
                end
            end
            X(mol, :, g) = feat';
        end
    end
end

function X_inv = computeInvariantFeatures(X)
%COMPUTEINVARIANTFEATURES Pool every feature over the third dimension.
%   X_inv = computeInvariantFeatures(X)
%
%   X      array whose third dimension is reduced (callers in this file pass
%          samples x features x rotations).
%   X_inv  size(X,1) x 4*size(X,2) matrix: [mean, std, min, max] of each
%          feature taken along dim 3, concatenated column-wise in that order.
%
%   Each of the four statistics is unchanged by any reordering of the pages
%   along dim 3 (up to floating-point summation order for mean and std).
%   std uses the default N-1 normalization (weight flag 0).
    X_mean = mean(X, 3);
    X_std = std(X, 0, 3);
    X_min = min(X, [], 3);
    X_max = max(X, [], 3);

    X_inv = [X_mean, X_std, X_min, X_max];
end

function r2 = computeR2(y_pred, y_true)
%COMPUTER2 Coefficient of determination of predictions against targets.
%   r2 = computeR2(y_pred, y_true)
%
%   Both inputs are flattened to columns first, so row/column orientation
%   does not matter. Returns 1 - SS_res / SS_tot, where 1e-10 is added to
%   SS_tot so constant targets do not cause a division by zero.
    target = y_true(:);
    estimate = y_pred(:);

    ss_res = sum((estimate - target).^2);
    ss_tot = sum((target - mean(target)).^2) + 1e-10;

    r2 = 1 - ss_res / ss_tot;
end

function n_params = countMLPParams(layers)
%COUNTMLPPARAMS Number of weights and biases in a fully connected network.
%   n_params = countMLPParams(layers)
%
%   layers  vector of layer widths, input first and output last.
%   Each consecutive pair (in, out) contributes in*out weights plus out
%   biases. A vector with fewer than two entries yields 0.
    fan_in = layers(1:end-1);
    fan_out = layers(2:end);
    n_params = sum(fan_in .* fan_out + fan_out);
end

function [W, b] = trainMLP(X_train, Y_train, X_val, Y_val, hidden, epochs, lr, verbose)
%TRAINMLP Train a ReLU multilayer perceptron with mini-batch SGD on squared error.
%   [W, b] = trainMLP(X_train, Y_train, X_val, Y_val, hidden, epochs, lr, verbose)
%
%   X_train, X_val  samples x features matrices.
%   Y_train, Y_val  one scalar target per sample (row or column vector).
%   hidden          row vector of hidden-layer widths; the output layer has 1 unit.
%   epochs          maximum number of passes over the training set.
%   lr              step size of the plain gradient update.
%   verbose         optional (default false); print validation loss every 20 epochs.
%
%   Returns cell arrays W and b (one entry per layer) holding the parameters
%   with the lowest validation MSE seen during training. Training stops early
%   after 15 consecutive epochs without improvement.
%
%   Weights are drawn as randn * sqrt(2 / fan_in), biases start at zero.
%   Uses the global random stream (randn, randperm).
    if nargin < 8, verbose = false; end

    layers = [size(X_train, 2), hidden, 1];
    nL = length(layers) - 1;

    W = cell(nL, 1);
    b = cell(nL, 1);
    for l = 1:nL
        W{l} = randn(layers(l+1), layers(l)) * sqrt(2 / layers(l));
        b{l} = zeros(layers(l+1), 1);
    end

    n = size(X_train, 1);
    bs = min(64, n);          % mini-batch size
    best_loss = inf;
    best_W = W;
    best_b = b;
    wait = 0;                 % epochs since the last validation improvement
    patience = 15;

    for epoch = 1:epochs
        perm = randperm(n);
        for i = 1:bs:n
            idx = perm(i:min(i+bs-1, n));
            Xb = X_train(idx, :);
            Yb = Y_train(idx);

            % ---- forward pass (columns are samples) ----
            % As{k} stores the input to layer k; As{nL+1} is the network output.
            A = Xb';
            As = cell(nL + 1, 1);
            As{1} = A;
            for l = 1:nL
                Z = W{l} * A + b{l};
                % ReLU on hidden layers, identity on the output layer. Explicit
                % branching avoids the 0 * Inf = NaN of a multiply-by-mask.
                if l < nL
                    A = max(0, Z);
                else
                    A = Z;
                end
                As{l+1} = A;
            end

            % ---- backward pass ----
            % Yb(:)' forces a row so a row-vector Y_train cannot silently
            % broadcast into a B x B matrix.
            dA = 2 * (A - Yb(:)') / numel(Yb);
            for l = nL:-1:1
                % ReLU derivative: pass the gradient only where the unit was active.
                if l < nL, dA = dA .* (As{l+1} > 0); end

                % NOTE(review): dA already carries the 1/B factor of the mean
                % squared error, and dW/db divide by the batch size again, so
                % the effective step is lr / B. Left unchanged so that existing
                % lr settings keep their behaviour -- confirm whether intended.
                dW = dA * As{l}' / size(dA, 2);
                db = mean(dA, 2);

                % Propagate to the previous layer with the weights that were
                % used in the forward pass, i.e. BEFORE this layer is updated.
                if l > 1
                    dA_prev = W{l}' * dA;
                end

                W{l} = W{l} - lr * dW;
                b{l} = b{l} - lr * db;

                if l > 1, dA = dA_prev; end
            end
        end

        % ---- validation and early stopping ----
        Yp = forwardMLP(X_val, W, b);
        val_loss = mean((Yp(:) - Y_val(:)).^2);
        if val_loss < best_loss
            best_loss = val_loss;
            best_W = W;
            best_b = b;
            wait = 0;
        else
            wait = wait + 1;
        end
        if wait >= patience, break; end

        if verbose && mod(epoch, 20) == 0
            fprintf('  Epoch %d: val_loss = %.4f\n', epoch, val_loss);
        end
    end

    % Return the best checkpoint, not the final iterate.
    W = best_W;
    b = best_b;
end

function Y = forwardMLP(X, W, b)
%FORWARDMLP Evaluate a ReLU multilayer perceptron.
%   Y = forwardMLP(X, W, b)
%
%   X     samples x features matrix.
%   W, b  cell arrays of per-layer weight matrices (out x in) and bias
%         columns (out x 1), as returned by trainMLP.
%   Y     samples x outputs matrix. Hidden layers use ReLU; the last layer
%         is linear. With an empty W the input is returned unchanged.
    n_layers = length(W);
    A = X';                       % columns are samples
    for l = 1:n_layers
        Z = W{l} * A + b{l};
        % Branch instead of mixing both activations with 0/1 multipliers:
        % 0 * (+/-Inf) is NaN, which turned a saturated hidden unit
        % (pre-activation -Inf, ReLU output 0) into NaN.
        if l < n_layers
            A = max(0, Z);
        else
            A = Z;
        end
    end
    Y = A';
end