%% run_neural_starG_vs_ENN.m
%% Neural ★_G vs Equivariant Neural Network Comparison
%% For Nature Communications
%% ============================================================================
%%
%% This script compares:
%% 1. ★_G-SVD + Ridge (linear, exact invariance) - OUR METHOD
%% 2. Neural ★_G (non-linear, exact invariance) - OUR METHOD
%% 3. Standard MLP (non-linear, NO invariance) - BASELINE
%% 4. Invariant MLP (non-linear, approximate invariance) - ENN PROXY
%% 5. Augmented MLP (non-linear, learned invariance) - ENN PROXY
%%
%% LH & SU & Claude 2026
%% ============================================================================
clear; clc; close all;
fprintf('═══════════════════════════════════════════════════════════════════\n');
fprintf(' Neural ★_G vs Equivariant Neural Networks\n');
fprintf(' Comprehensive Comparison for Nature Communications\n');
fprintf('═══════════════════════════════════════════════════════════════════\n\n');
resultsDir = 'neural_vs_enn_results';
if ~exist(resultsDir, 'dir'), mkdir(resultsDir); end
%% ========================================================================
%% Configuration
%% ========================================================================
n_samples = 1500;
n_rot = 12;
n_feat = 20;
n_runs = 3;
data_seed = 1;    % seeds the synthetic data generation
split_seed = 42;  % seeds the train/val/test permutation
fprintf('Configuration:\n');
fprintf(' Samples: %d\n', n_samples);
fprintf(' Rotations: %d\n', n_rot);
fprintf(' Features: %d\n', n_feat);
fprintf(' Runs: %d\n\n', n_runs);
%% ========================================================================
%% Generate Data
%% ========================================================================
fprintf('═══════════════════════════════════════════════════════════════════\n');
fprintf(' Generating Molecular Dataset\n');
fprintf('═══════════════════════════════════════════════════════════════════\n\n');
G = StarGAlgebra('cyclic', n_rot);
% generateMolecularData draws from the global RNG (randi/randn); seed it
% immediately beforehand so the dataset is identical on every invocation.
rng(data_seed);
[X_data, Y_data] = generateMolecularData(n_samples, n_feat, n_rot);
fprintf('Data shape: [%d, %d, %d]\n', size(X_data));
%% ========================================================================
%% Train/Val/Test Split + Normalization
%% ========================================================================
rng(split_seed);
n_train = round(0.7 * n_samples);
n_val = round(0.15 * n_samples);
n_test = n_samples - n_train - n_val;
perm = randperm(n_samples);
train_idx = perm(1:n_train);
val_idx = perm(n_train+1:n_train+n_val);
test_idx = perm(n_train+n_val+1:end);
% Normalize using statistics of the TRAINING split only, so nothing about
% the validation/test samples leaks into preprocessing.
X_train_raw = X_data(train_idx, :, :);
X_mean = mean(X_train_raw(:));
X_std = std(X_train_raw(:)) + 1e-10;
X_data = (X_data - X_mean) / X_std;
Y_mean = mean(Y_data(train_idx));
Y_std = std(Y_data(train_idx)) + 1e-10;
Y_data = (Y_data - Y_mean) / Y_std;
fprintf('Target range: [%.2f, %.2f]\n\n', min(Y_data), max(Y_data));
X_train = X_data(train_idx, :, :);
X_val = X_data(val_idx, :, :);
X_test = X_data(test_idx, :, :);
Y_train = Y_data(train_idx);
Y_val = Y_data(val_idx);
Y_test = Y_data(test_idx);
fprintf('Split: Train=%d, Val=%d, Test=%d\n\n', n_train, n_val, n_test);
%% ========================================================================
%% Verify Exact Feature Invariance
%% ========================================================================
fprintf('╔══════════════════════════════════════════════════════════════════╗\n');
fprintf('║ Verifying Exact Feature Invariance ║\n');
fprintf('╚══════════════════════════════════════════════════════════════════╝\n\n');
% Fit the feature normalization on a small training batch, then check that
% one sample yields the same features under every cyclic rotation
% (circshift along the rotation axis, dim 3).
n_verify = min(10, size(X_train, 1));
test_samples = X_train(1:n_verify, :, :);
[~, norm_params_test] = extractStarGFeatures(test_samples, G, n_feat);
test_sample = X_train(1, :, :);
feat_orig = extractStarGFeatures(test_sample, G, n_feat, norm_params_test);
fprintf('Test sample shape: [%d, %d, %d]\n', size(test_sample));
fprintf('Feature vector length: %d\n\n', size(feat_orig, 2));
fprintf('Rotation Angle Max |Δfeature|\n');
fprintf('%s\n', repmat('-', 1, 40));
max_diff_all = 0;
for g = 1:n_rot
    test_rot = circshift(test_sample, g-1, 3);
    feat_rot = extractStarGFeatures(test_rot, G, n_feat, norm_params_test);
    feat_diff = max(abs(feat_orig(:) - feat_rot(:)));
    max_diff_all = max(max_diff_all, feat_diff);
    % %6.1f: the angle is non-integer whenever n_rot does not divide 360
    rot_angle = (g-1) * 360 / n_rot;
    fprintf(' %2d %6.1f° %.2e\n', g, rot_angle, feat_diff);
end
fprintf('\nMaximum difference: %.2e\n', max_diff_all);
if max_diff_all < 1e-12
    fprintf('✓ Features are EXACTLY invariant (machine precision)\n\n');
elseif max_diff_all < 1e-8
    fprintf('✓ Features are effectively invariant\n\n');
else
    fprintf('⚠ WARNING: Features are NOT invariant!\n\n');
end
%% ========================================================================
%% Initialize Results Storage
%% ========================================================================
methods = {'StarG_SVD', 'Neural_StarG', 'MLP_r1', 'Invariant_MLP', 'Augmented_MLP'};
n_methods = numel(methods);
% Every method starts from the same zero-filled record:
%   test_r2      - test R² per run
%   rotation_r2  - test R² per run and per cyclic rotation
%   rotation_std - std of rotation_r2 per run
%   n_params     - parameter count of the method
%   train_time   - training wall-clock time per run (s)
blank_entry = struct( ...
    'test_r2',      zeros(n_runs, 1), ...
    'rotation_r2',  zeros(n_runs, n_rot), ...
    'rotation_std', zeros(n_runs, 1), ...
    'n_params',     0, ...
    'train_time',   zeros(n_runs, 1));
results = struct();
for m = 1:n_methods
    results.(methods{m}) = blank_entry;
end
%% ========================================================================
%% Run Multiple Experiments
%% ========================================================================
% Each run re-seeds the global RNG (MLP initialisation, mini-batch order)
% but reuses the single train/val/test split built above.
for run = 1:n_runs
    fprintf('═══════════════════════════════════════════════════════════════════\n');
    fprintf(' RUN %d / %d\n', run, n_runs);
    fprintf('═══════════════════════════════════════════════════════════════════\n\n');
    rng(run * 123);
    %% ====================================================================
    %% Method 1: ★_G-SVD + Ridge (Exact Invariance)
    %% ====================================================================
    fprintf('── Method 1: ★_G-SVD + Ridge (Exact Invariance) ──\n');
    tic;
    [train_feat, norm_params] = extractStarGFeatures(X_train, G, n_feat);
    val_feat = extractStarGFeatures(X_val, G, n_feat, norm_params);
    test_feat = extractStarGFeatures(X_test, G, n_feat, norm_params);
    n_extracted = size(train_feat, 2);
    fprintf(' Exact invariant features: %d\n', n_extracted);
    % Ridge regression; lambda is selected by validation R².
    % NOTE(review): no intercept column is fitted, which assumes the
    % extracted features / targets are roughly centred — confirm for
    % extractStarGFeatures.
    lambdas = [0.001, 0.01, 0.1, 1, 10];
    best_lambda = 0.1;
    best_val_r2 = -inf;
    gram = train_feat' * train_feat;   % loop-invariant normal-equation terms
    rhs = train_feat' * Y_train;
    for lam = lambdas
        W_temp = (gram + lam * eye(n_extracted)) \ rhs;
        val_r2 = computeR2(val_feat * W_temp, Y_val);
        if val_r2 > best_val_r2
            best_val_r2 = val_r2;
            best_lambda = lam;
        end
    end
    fprintf(' Best lambda: %.4f\n', best_lambda);
    W_svd = (gram + best_lambda * eye(n_extracted)) \ rhs;
    results.StarG_SVD.train_time(run) = toc;
    Y_pred = test_feat * W_svd;
    results.StarG_SVD.test_r2(run) = computeR2(Y_pred, Y_test);
    results.StarG_SVD.n_params = n_extracted;
    % Re-evaluate on every cyclic rotation of the test set.
    rot_r2 = zeros(n_rot, 1);
    for g = 1:n_rot
        X_test_rot = circshift(X_test, g-1, 3);
        feat_g = extractStarGFeatures(X_test_rot, G, n_feat, norm_params);
        rot_r2(g) = computeR2(feat_g * W_svd, Y_test);
    end
    results.StarG_SVD.rotation_r2(run, :) = rot_r2';
    results.StarG_SVD.rotation_std(run) = std(rot_r2);
    fprintf(' R² = %.4f, σ = %.2e, params = %d\n', ...
        results.StarG_SVD.test_r2(run), results.StarG_SVD.rotation_std(run), n_extracted);
    if results.StarG_SVD.rotation_std(run) < 1e-10
        fprintf(' ✓ EXACT invariance achieved!\n\n');
    else
        fprintf(' Rotation variance: %.2e\n\n', results.StarG_SVD.rotation_std(run));
    end
    %% ====================================================================
    %% Method 2: Neural ★_G (Exact Invariance)
    %% ====================================================================
    fprintf('── Method 2: Neural ★_G (Exact Invariance) ──\n');
    tic;
    train_feat_nn = train_feat;
    val_feat_nn = val_feat;
    test_feat_nn = test_feat;
    n_pca = min(50, n_extracted);
    if n_extracted > 50
        % Reduce to at most 50 inputs; the projection is linear, so
        % invariance of the features is preserved.
        [coeff_pca, ~, ~] = pca(train_feat);
        pca_coeff = coeff_pca(:, 1:n_pca);
        train_feat_nn = train_feat * pca_coeff;
        val_feat_nn = val_feat * pca_coeff;
        test_feat_nn = test_feat * pca_coeff;
        use_pca = true;
    else
        pca_coeff = eye(n_extracted);
        n_pca = n_extracted;
        use_pca = false;
    end
    fprintf(' Features for MLP: %d\n', n_pca);
    hidden_layers = [32, 16];
    [W_nn, b_nn] = trainMLP(train_feat_nn, Y_train, val_feat_nn, Y_val, ...
        hidden_layers, 200, 0.005, false);
    results.Neural_StarG.train_time(run) = toc;
    Y_pred = forwardMLP(test_feat_nn, W_nn, b_nn);
    results.Neural_StarG.test_r2(run) = computeR2(Y_pred, Y_test);
    results.Neural_StarG.n_params = countMLPParams([n_pca, hidden_layers, 1]);
    rot_r2 = zeros(n_rot, 1);
    for g = 1:n_rot
        X_test_rot = circshift(X_test, g-1, 3);
        feat_g = extractStarGFeatures(X_test_rot, G, n_feat, norm_params);
        if use_pca
            feat_g = feat_g * pca_coeff;
        end
        rot_r2(g) = computeR2(forwardMLP(feat_g, W_nn, b_nn), Y_test);
    end
    results.Neural_StarG.rotation_r2(run, :) = rot_r2';
    results.Neural_StarG.rotation_std(run) = std(rot_r2);
    fprintf(' R² = %.4f, σ = %.2e, params = %d\n', ...
        results.Neural_StarG.test_r2(run), results.Neural_StarG.rotation_std(run), ...
        results.Neural_StarG.n_params);
    if results.Neural_StarG.rotation_std(run) < 1e-10
        fprintf(' ✓ EXACT invariance achieved!\n\n');
    else
        fprintf(' Rotation variance: %.2e\n\n', results.Neural_StarG.rotation_std(run));
    end
    %% ====================================================================
    %% Method 3: Standard MLP (No Invariance)
    %% ====================================================================
    fprintf('── Method 3: Standard MLP (rotation 1 only) ──\n');
    % Rotation slice 1 only; X(:, :, 1) of an n x f x r array is already n x f.
    X_train_r1 = X_train(:, :, 1);
    X_val_r1 = X_val(:, :, 1);
    X_test_r1 = X_test(:, :, 1);
    tic;
    [W_mlp, b_mlp] = trainMLP(X_train_r1, Y_train, X_val_r1, Y_val, [64, 32], 100, 0.005, false);
    results.MLP_r1.train_time(run) = toc;
    Y_pred = forwardMLP(X_test_r1, W_mlp, b_mlp);
    results.MLP_r1.test_r2(run) = computeR2(Y_pred, Y_test);
    results.MLP_r1.n_params = countMLPParams([n_feat, 64, 32, 1]);
    for g = 1:n_rot
        Y_pred_g = forwardMLP(X_test(:, :, g), W_mlp, b_mlp);
        results.MLP_r1.rotation_r2(run, g) = computeR2(Y_pred_g, Y_test);
    end
    results.MLP_r1.rotation_std(run) = std(results.MLP_r1.rotation_r2(run, :));
    fprintf(' R² = %.4f, σ = %.4f, params = %d\n\n', ...
        results.MLP_r1.test_r2(run), results.MLP_r1.rotation_std(run), results.MLP_r1.n_params);
    %% ====================================================================
    %% Method 4: Invariant MLP (ENN Proxy)
    %% ====================================================================
    fprintf('── Method 4: Invariant MLP (ENN-style) ──\n');
    X_train_inv = computeInvariantFeatures(X_train);
    X_val_inv = computeInvariantFeatures(X_val);
    X_test_inv = computeInvariantFeatures(X_test);
    n_inv_feat = size(X_train_inv, 2);
    tic;
    [W_inv, b_inv] = trainMLP(X_train_inv, Y_train, X_val_inv, Y_val, [64, 32], 100, 0.005, false);
    results.Invariant_MLP.train_time(run) = toc;
    Y_pred = forwardMLP(X_test_inv, W_inv, b_inv);
    results.Invariant_MLP.test_r2(run) = computeR2(Y_pred, Y_test);
    results.Invariant_MLP.n_params = countMLPParams([n_inv_feat, 64, 32, 1]);
    for g = 1:n_rot
        X_test_inv_g = computeInvariantFeatures(circshift(X_test, g-1, 3));
        Y_pred_g = forwardMLP(X_test_inv_g, W_inv, b_inv);
        results.Invariant_MLP.rotation_r2(run, g) = computeR2(Y_pred_g, Y_test);
    end
    results.Invariant_MLP.rotation_std(run) = std(results.Invariant_MLP.rotation_r2(run, :));
    fprintf(' R² = %.4f, σ = %.2e, params = %d\n\n', ...
        results.Invariant_MLP.test_r2(run), results.Invariant_MLP.rotation_std(run), ...
        results.Invariant_MLP.n_params);
    %% ====================================================================
    %% Method 5: Augmented MLP (Data Augmentation)
    %% ====================================================================
    fprintf('── Method 5: MLP with Full Augmentation ──\n');
    % Every rotation of every training sample becomes its own row:
    % row (i-1)*n_rot + g holds X_train(i, :, g), paired with Y_train(i).
    X_train_aug = reshape(permute(X_train, [3, 1, 2]), [], size(X_train, 2));
    Y_train_aug = repelem(Y_train, n_rot);
    tic;
    [W_aug, b_aug] = trainMLP(X_train_aug, Y_train_aug, X_val_r1, Y_val, [64, 32], 80, 0.005, false);
    results.Augmented_MLP.train_time(run) = toc;
    Y_pred = forwardMLP(X_test_r1, W_aug, b_aug);
    results.Augmented_MLP.test_r2(run) = computeR2(Y_pred, Y_test);
    results.Augmented_MLP.n_params = countMLPParams([n_feat, 64, 32, 1]);
    for g = 1:n_rot
        Y_pred_g = forwardMLP(X_test(:, :, g), W_aug, b_aug);
        results.Augmented_MLP.rotation_r2(run, g) = computeR2(Y_pred_g, Y_test);
    end
    results.Augmented_MLP.rotation_std(run) = std(results.Augmented_MLP.rotation_r2(run, :));
    fprintf(' R² = %.4f, σ = %.4f, params = %d\n\n', ...
        results.Augmented_MLP.test_r2(run), results.Augmented_MLP.rotation_std(run), ...
        results.Augmented_MLP.n_params);
end
%% ========================================================================
%% Aggregate Results
%% ========================================================================
fprintf('\n');
fprintf('═══════════════════════════════════════════════════════════════════\n');
fprintf(' AGGREGATED RESULTS (Mean ± Std over %d runs)\n', n_runs);
fprintf('═══════════════════════════════════════════════════════════════════\n\n');
fprintf('%-20s %-15s %-20s %-10s %-12s\n', ...
    'Method', 'Test R²', 'Rotation σ', 'Params', 'Time (s)');
fprintf('%s\n', repmat('-', 1, 85));
% summary.(method) keeps the per-method means used by the figures/tables;
% the spread of rotation σ across runs is only printed.
summary = struct();
for m = 1:n_methods
    method = methods{m};
    rec = results.(method);
    entry = struct();
    entry.mean_r2 = mean(rec.test_r2);
    entry.std_r2 = std(rec.test_r2);
    entry.mean_rot_std = mean(rec.rotation_std);
    entry.n_params = rec.n_params;
    entry.mean_time = mean(rec.train_time);
    summary.(method) = entry;
    rot_std_spread = std(rec.rotation_std);
    fprintf('%-20s %.4f ± %.4f %.2e ± %.2e %-10d %.2f\n', ...
        strrep(method, '_', ' '), entry.mean_r2, entry.std_r2, ...
        entry.mean_rot_std, rot_std_spread, entry.n_params, entry.mean_time);
end
%% ========================================================================
%% Statistical Significance Tests
%% ========================================================================
fprintf('\n═══════════════════════════════════════════════════════════════════\n');
fprintf(' STATISTICAL ANALYSIS\n');
fprintf('═══════════════════════════════════════════════════════════════════\n\n');
fprintf('Comparing Neural ★_G vs other methods:\n\n');
baseline_name = 'Neural_StarG';
baseline_r2 = results.(baseline_name).test_r2;
baseline_std = results.(baseline_name).rotation_std;
% Compare against every other method (derived, not hard-coded indices).
for m = setdiff(1:n_methods, find(strcmp(methods, baseline_name)))
    method = methods{m};
    other_r2 = results.(method).test_r2;
    other_std = results.(method).rotation_std;
    % ttest2 (two-sample t-test over runs) needs at least 2 runs per group;
    % it still yields NaN when both groups have zero variance.
    if n_runs >= 2
        [~, p_r2] = ttest2(baseline_r2, other_r2);
    else
        p_r2 = NaN;
    end
    % Ratio of mean rotation σ (a standard deviation, not a variance);
    % the 1e-15 guard keeps an exactly invariant baseline from dividing by 0.
    sigma_ratio = mean(other_std) / (mean(baseline_std) + 1e-15);
    fprintf(' vs %s:\n', strrep(method, '_', ' '));
    fprintf(' R² difference: %.4f (p = %.4f)\n', mean(baseline_r2) - mean(other_r2), p_r2);
    fprintf(' Rotation σ ratio: %.2e\n\n', sigma_ratio);
end
%% ========================================================================
%% Generate Publication Figures
%% ========================================================================
fprintf('═══════════════════════════════════════════════════════════════════\n');
fprintf(' Generating Publication Figures\n');
fprintf('═══════════════════════════════════════════════════════════════════\n\n');
fig = figure('Position', [50, 50, 1800, 600], 'Color', 'w');
% Panel A: Test R² Comparison
subplot(1,4,1);
method_labels = {'★_G-SVD', 'Neural ★_G', 'MLP', 'Inv. MLP', 'Aug. MLP'};
mean_r2_all = zeros(n_methods, 1);
std_r2_all = zeros(n_methods, 1);
for m = 1:n_methods
    mean_r2_all(m) = summary.(methods{m}).mean_r2;
    std_r2_all(m) = summary.(methods{m}).std_r2;
end
bar_h = bar(mean_r2_all, 'FaceColor', 'flat');
bar_h.CData(1,:) = [0.2, 0.6, 0.9];
bar_h.CData(2,:) = [0.1, 0.4, 0.8];
bar_h.CData(3,:) = [0.9, 0.3, 0.3];
bar_h.CData(4,:) = [0.9, 0.6, 0.2];
bar_h.CData(5,:) = [0.5, 0.8, 0.3];
hold on;
errorbar(1:n_methods, mean_r2_all, std_r2_all, 'k.', 'LineWidth', 1.5);
set(gca, 'XTickLabel', method_labels, 'XTickLabelRotation', 30);
ylabel('Test R²', 'FontSize', 13, 'FontWeight', 'bold');
title('A: Predictive Performance', 'FontSize', 15, 'FontWeight', 'bold');
grid on; box on;
set(gca, 'FontSize', 11);
% Panel B: Rotation Variance (log scale; exact zeros floored at 1e-16)
subplot(1,4,2);
mean_rot_std_all = zeros(n_methods, 1);
for m = 1:n_methods
    mean_rot_std_all(m) = max(summary.(methods{m}).mean_rot_std, 1e-16);
end
bar_h2 = bar(mean_rot_std_all, 'FaceColor', 'flat');
bar_h2.CData = bar_h.CData;
set(gca, 'XTickLabel', method_labels, 'XTickLabelRotation', 30);
ylabel('Rotation Variance (σ)', 'FontSize', 13, 'FontWeight', 'bold');
title('B: Rotation Robustness', 'FontSize', 15, 'FontWeight', 'bold');
set(gca, 'YScale', 'log');
grid on; box on;
set(gca, 'FontSize', 11);
% The σ of R² across rotations can exceed 1 for non-invariant models
% (R² on rotated inputs may be strongly negative), so do not clip at 1.
ylim([1e-16, max(1, 10 * max(mean_rot_std_all))]);
% Panel C: Parameter Efficiency
subplot(1,4,3);
n_params_all = zeros(n_methods, 1);
for m = 1:n_methods
    n_params_all(m) = summary.(methods{m}).n_params;
end
scatter(n_params_all, mean_r2_all, 200, 'filled');
hold on;
for m = 1:n_methods
    text(n_params_all(m) * 1.1, mean_r2_all(m), method_labels{m}, 'FontSize', 10);
end
xlabel('Number of Parameters', 'FontSize', 13, 'FontWeight', 'bold');
ylabel('Test R²', 'FontSize', 13, 'FontWeight', 'bold');
title('C: Parameter Efficiency', 'FontSize', 15, 'FontWeight', 'bold');
set(gca, 'XScale', 'log');
grid on; box on;
set(gca, 'FontSize', 11);
% Panel D: Rotation Performance Profile (all methods, in `methods` order)
subplot(1,4,4);
angles = (0:n_rot-1) * 360 / n_rot;
mean_rot_r2 = zeros(n_methods, n_rot);
for m = 1:n_methods
    mean_rot_r2(m, :) = mean(results.(methods{m}).rotation_r2, 1);
end
plot(angles, mean_rot_r2(1, :), 'b-o', 'LineWidth', 2, 'MarkerSize', 8, 'MarkerFaceColor', 'b');
hold on;
plot(angles, mean_rot_r2(2, :), 'b--s', 'LineWidth', 2, 'MarkerSize', 8);
plot(angles, mean_rot_r2(3, :), 'r-^', 'LineWidth', 2, 'MarkerSize', 8);
plot(angles, mean_rot_r2(4, :), '-v', 'Color', [0.9, 0.6, 0.2], 'LineWidth', 2, 'MarkerSize', 8);
plot(angles, mean_rot_r2(5, :), 'g-d', 'LineWidth', 2, 'MarkerSize', 8);
xlabel('Rotation Angle (°)', 'FontSize', 13, 'FontWeight', 'bold');
ylabel('Test R²', 'FontSize', 13, 'FontWeight', 'bold');
title('D: Performance vs Rotation', 'FontSize', 15, 'FontWeight', 'bold');
legend(method_labels, 'Location', 'best', 'FontSize', 10);
grid on; box on;
set(gca, 'FontSize', 11);
sgtitle('Neural ★_G vs Equivariant Neural Networks', 'FontSize', 18, 'FontWeight', 'bold');
saveas(fig, fullfile(resultsDir, 'neural_starG_vs_enn.png'));
saveas(fig, fullfile(resultsDir, 'neural_starG_vs_enn.fig'));
fprintf('Figures saved to: %s/\n\n', resultsDir);
%% ========================================================================
%% Comparison Table for Paper
%% ========================================================================
fprintf('═══════════════════════════════════════════════════════════════════\n');
fprintf(' TABLE FOR PUBLICATION\n');
fprintf('═══════════════════════════════════════════════════════════════════\n\n');
% Markdown table; the "Rot. σ" column is the mean (over runs) std of R²
% across the n_rot cyclic rotations of the test set.
fprintf('| Method | Test R² | Rot. σ | Parameters | Invariant? |\n');
fprintf('|--------|---------|--------|------------|------------|\n');
% One flag per entry of `methods`, in the same order.
invariant_flags = {'Yes (exact)', 'Yes (exact)', 'No', 'Yes (approx)', 'No'};
for m = 1:n_methods
    method_name = strrep(methods{m}, '_', ' ');
    fprintf('| %s | %.3f ± %.3f | %.2e | %d | %s |\n', ...
        method_name, summary.(methods{m}).mean_r2, summary.(methods{m}).std_r2, ...
        summary.(methods{m}).mean_rot_std, summary.(methods{m}).n_params, invariant_flags{m});
end
%% ========================================================================
%% Key Findings Summary
%% ========================================================================
fprintf('\n═══════════════════════════════════════════════════════════════════\n');
fprintf(' KEY FINDINGS FOR NATURE COMMUNICATIONS\n');
fprintf('═══════════════════════════════════════════════════════════════════\n\n');
fprintf('1. EXACT INVARIANCE:\n');
% Report the invariance claim only if the measured σ supports it (same
% 1e-10 threshold as the per-run checks).
starG_sigma = max(summary.StarG_SVD.mean_rot_std, summary.Neural_StarG.mean_rot_std);
if starG_sigma < 1e-10
    fprintf(' • ★_G methods achieve EXACT rotation invariance (σ ≈ 0)\n');
else
    fprintf(' • ⚠ ★_G methods are NOT exactly invariant (max σ = %.2e)\n', starG_sigma);
end
fprintf(' • Standard MLP: σ = %.4f\n', summary.MLP_r1.mean_rot_std);
fprintf(' • Augmented MLP: σ = %.4f\n', summary.Augmented_MLP.mean_rot_std);
fprintf('\n2. PARAMETER EFFICIENCY:\n');
fprintf(' • Neural ★_G: %d parameters\n', summary.Neural_StarG.n_params);
param_ratio = summary.MLP_r1.n_params / max(summary.Neural_StarG.n_params, 1);
if param_ratio >= 1
    fprintf(' • Standard MLP: %d parameters (%.1fx more)\n', summary.MLP_r1.n_params, param_ratio);
else
    fprintf(' • Standard MLP: %d parameters (%.1fx fewer)\n', summary.MLP_r1.n_params, 1 / param_ratio);
end
fprintf('\n3. PREDICTIVE PERFORMANCE:\n');
fprintf(' • ★_G-SVD: R² = %.4f\n', summary.StarG_SVD.mean_r2);
fprintf(' • Neural ★_G: R² = %.4f\n', summary.Neural_StarG.mean_r2);
fprintf(' • Best baseline: R² = %.4f\n', max([summary.MLP_r1.mean_r2, summary.Augmented_MLP.mean_r2, summary.Invariant_MLP.mean_r2]));
fprintf('\n═══════════════════════════════════════════════════════════════════\n');
fprintf(' Analysis Complete!\n');
fprintf('═══════════════════════════════════════════════════════════════════\n\n');
save(fullfile(resultsDir, 'comparison_results.mat'), 'results', 'summary', 'methods');
%% ========================================================================
%% HELPER FUNCTIONS
%% ========================================================================
function [X, Y] = generateMolecularData(n_samples, n_feat, n_rot)
% GENERATEMOLECULARDATA  Random point-cloud "molecules" seen under z-axis rotations.
%   [X, Y] = generateMolecularData(n_samples, n_feat, n_rot)
%   X : n_samples x n_feat x n_rot; X(s, :, g) describes molecule s after a
%       rotation by 2*pi*(g-1)/n_rot about the z axis.
%   Y : n_samples x 1 target built from interatomic distances and atom
%       types only, so it does not depend on the rotation.
%   Draws from the global RNG (randi/randn) — seed before calling for
%   reproducible data. Features 1-5 are always written, so n_feat must be
%   at least 5 for the final assignment to match in size.
X = zeros(n_samples, n_feat, n_rot);
Y = zeros(n_samples, 1);
for s = 1:n_samples
    % Random molecule; RNG draw order is randi, randn, randi.
    n_atoms = randi([4, 10]);
    coords = randn(n_atoms, 3) * 2;
    coords = coords - mean(coords, 1);          % centre at the origin
    charge = randi([1, 9], n_atoms, 1);
    charge(charge > 1 & charge < 6) = 6;        % types 2..5 become 6
    pair_d = pdist(coords);
    if isempty(pair_d)
        pair_d = 1;
    end
    Y(s) = mean(pair_d) + 0.3 * std(pair_d) + sum(charge.^2) / 500;
    % Charge-weighted descriptors of each rotated copy.
    for g = 1:n_rot
        ang = 2 * pi * (g - 1) / n_rot;
        Rz = [cos(ang), -sin(ang), 0; sin(ang), cos(ang), 0; 0, 0, 1];
        rotated = (Rz * coords')';
        desc = zeros(n_feat, 1);
        for a = 1:n_atoms
            p = rotated(a, :);
            q = charge(a);
            rad = norm(p);
            desc(1) = desc(1) + q * p(1);
            desc(2) = desc(2) + q * p(2);
            desc(3) = desc(3) + q * p(3);
            desc(4) = desc(4) + q * rad;
            desc(5) = desc(5) + q * rad^2;
            % angular harmonics of the in-plane angle with a Gaussian envelope
            phi = atan2(p(2), p(1));
            envelope = exp(-rad^2 / 8);
            for f = 6:n_feat
                desc(f) = desc(f) + q * cos((f-5) * phi) * envelope;
            end
        end
        X(s, :, g) = desc';
    end
end
end
function X_inv = computeInvariantFeatures(X)
% COMPUTEINVARIANTFEATURES  Pool features over the rotation axis (dim 3).
%   X_inv = computeInvariantFeatures(X) takes X of size n x f x r and
%   returns the n x 4f matrix [mean, std, min, max], each pooled over dim 3.
%   These pooled statistics ignore the ordering along dim 3, so the result
%   is invariant to circshift along that axis (min/max exactly; mean/std up
%   to floating-point summation order).
X_inv = [mean(X, 3), std(X, 0, 3), min(X, [], 3), max(X, [], 3)];
end
function r2 = computeR2(y_pred, y_true)
% COMPUTER2  Coefficient of determination, R² = 1 - SS_res / SS_tot.
%   Both inputs are flattened column-wise, so any pair of equally sized
%   arrays works. A 1e-10 term added to SS_tot avoids division by zero
%   for constant targets.
resid = y_pred(:) - y_true(:);
centred = y_true(:) - mean(y_true(:));
r2 = 1 - sum(resid.^2) / (sum(centred.^2) + 1e-10);
end
function n_params = countMLPParams(layers)
% COUNTMLPPARAMS  Number of weights + biases of a fully connected network.
%   layers = [in, h1, ..., out]; each consecutive pair (a, b) contributes an
%   a-by-b weight matrix plus b biases. Fewer than two entries gives 0.
fan_in = layers(1:end-1);
fan_out = layers(2:end);
n_params = sum(fan_in .* fan_out + fan_out);
end
function [W, b] = trainMLP(X_train, Y_train, X_val, Y_val, hidden, epochs, lr, verbose)
% TRAINMLP  Train a ReLU MLP regressor on MSE loss with mini-batch SGD.
%   [W, b] = trainMLP(X_train, Y_train, X_val, Y_val, hidden, epochs, lr, verbose)
%   X_train : n x d inputs;  Y_train : n x 1 targets
%   X_val, Y_val : validation data used for early stopping (patience 15
%                  epochs). If X_val is empty, early stopping is disabled
%                  and the weights after the last epoch are returned.
%   hidden  : hidden-layer widths (row or column vector; [] = linear model)
%   epochs  : maximum number of passes over the training data
%   lr      : SGD step size
%   verbose : print validation loss every 20 epochs (default false)
%   Returns cell arrays W{l} (out x in) and b{l} (out x 1) for forwardMLP,
%   taken from the epoch with the lowest validation loss.
if nargin < 8, verbose = false; end
layers = [size(X_train, 2), hidden(:)', 1];
nL = numel(layers) - 1;
W = cell(nL, 1);
b = cell(nL, 1);
for l = 1:nL
    W{l} = randn(layers(l+1), layers(l)) * sqrt(2 / layers(l));  % He initialisation
    b{l} = zeros(layers(l+1), 1);
end
n = size(X_train, 1);
bs = min(64, n);
has_val = ~isempty(X_val);
best_loss = inf;
best_W = W;
best_b = b;
wait = 0;
patience = 15;
As = cell(nL + 1, 1);   % As{l} = input to layer l; As{nL+1} = network output
for epoch = 1:epochs
    order = randperm(n);
    for i = 1:bs:n
        idx = order(i:min(i+bs-1, n));
        Yb = Y_train(idx);
        % ---- forward pass (ReLU on hidden layers, linear output) ----
        As{1} = X_train(idx, :)';
        for l = 1:nL
            Z = W{l} * As{l} + b{l};
            if l < nL
                As{l+1} = max(0, Z);
            else
                As{l+1} = Z;
            end
        end
        % ---- backward pass ----
        % Loss = mean((yhat - y).^2). The 1/batch factor is applied once,
        % here, so the weight/bias gradients below are plain batch sums.
        dA = 2 * (As{nL+1} - Yb(:)') / numel(Yb);
        for l = nL:-1:1
            if l < nL
                dA = dA .* (As{l+1} > 0);   % ReLU derivative
            end
            dW = dA * As{l}';
            db = sum(dA, 2);
            if l > 1
                % Propagate through the weights used in the forward pass,
                % i.e. BEFORE this layer is updated.
                dA = W{l}' * dA;
            end
            W{l} = W{l} - lr * dW;
            b{l} = b{l} - lr * db;
        end
    end
    % ---- early stopping on validation MSE ----
    if has_val
        Yp = forwardMLP(X_val, W, b);
        val_loss = mean((Yp - Y_val(:)).^2);
        if val_loss < best_loss
            best_loss = val_loss;
            best_W = W;
            best_b = b;
            wait = 0;
        else
            wait = wait + 1;
            if wait >= patience, break; end
        end
    else
        % No validation data: no early stopping, keep the latest weights.
        val_loss = NaN;
        best_W = W;
        best_b = b;
    end
    if verbose && mod(epoch, 20) == 0
        fprintf(' Epoch %d: val_loss = %.4f\n', epoch, val_loss);
    end
end
W = best_W;
b = best_b;
end
function Y = forwardMLP(X, W, b)
% FORWARDMLP  Evaluate an MLP with the weight layout produced by trainMLP.
%   Y = forwardMLP(X, W, b) maps the n x d input X through
%   A <- W{l} * A + b{l} for each layer, applying ReLU after every layer
%   except the last (linear output), and returns the result transposed
%   (n x out; n x 1 for trainMLP networks).
%   The activation is chosen with a branch rather than by multiplying with
%   0/1 masks, so an infinite pre-activation no longer turns into NaN
%   through 0*Inf.
nL = numel(W);
A = X';
for l = 1:nL
    Z = W{l} * A + b{l};
    if l < nL
        A = max(0, Z);   % hidden layer: ReLU
    else
        A = Z;           % output layer: linear
    end
end
Y = A';
end