%% ========================================================================
%% run_learning_curves.m
%% Generate learning curves: R2 vs n_molecules for all methods
%% This produces the data for Figure 3D (QM9 scaling behavior)
%%
%% Usage:
%% >> run_learning_curves('qm9_dir', './data/xyz/')
%% >> run_learning_curves % synthetic fallback
%%
%% LH & Claude 2026
%% ========================================================================
function run_learning_curves(varargin)
%RUN_LEARNING_CURVES  Test R^2 versus number of molecules for four regressors.
%
%   run_learning_curves('Name', Value, ...) loops over a list of sample
%   sizes. For each size it builds a QM9_experiment, loads data, computes
%   rotated features, and for several random 70/15/15 train/val/test splits
%   fits four models:
%     1. starG-SVD features + ridge regression (lambda picked on validation)
%     2. MLP trained on every rotation of every training molecule
%     3. MLP trained on the first rotation only
%     4. MLP trained on per-feature mean/std taken across rotations
%   Test-set R^2 is stored in R2_curves(method, size, seed).
%
%   Name-value options:
%     'qm9_dir'       char, first argument to QM9_experiment (default '')
%     'sample_sizes'  numeric vector of molecule counts
%                     (default [100 200 500 750 1000]); row or column
%     'n_rotations'   second argument to QM9_experiment (default 12)
%     'target_col'    column of properties_mat used as target (default 8)
%     'n_seeds'       number of random splits per sample size (default 3)
%     'results_dir'   output folder, created if missing
%                     (default 'results/figures')
%
%   Files written to results_dir:
%     learning_curves.mat           (R2_curves, sizes, method_names)
%     fig3d_learning_curves.pdf/.png
%
%   Side effects: changes several groot default colours for the rest of
%   the MATLAB session and adds <root>/core and <root>/experiments to the
%   path.

    % FORCE WHITE THEME (session-wide; these defaults are not restored)
    set(groot, 'defaultFigureColor', 'w');
    set(groot, 'defaultAxesColor', 'w');
    set(groot, 'defaultAxesXColor', 'k');
    set(groot, 'defaultAxesYColor', 'k');
    set(groot, 'defaultTextColor', 'k');
    set(groot, 'defaultAxesGridColor', [0.8 0.8 0.8]);

    thisDir = fileparts(mfilename('fullpath'));
    rootDir = fileparts(thisDir);
    addpath(fullfile(rootDir, 'core'));
    addpath(fullfile(rootDir, 'experiments'));

    p = inputParser;
    addParameter(p, 'qm9_dir', '', @ischar);
    addParameter(p, 'sample_sizes', [100, 200, 500, 750, 1000], @isnumeric);
    addParameter(p, 'n_rotations', 12, @isnumeric);
    addParameter(p, 'target_col', 8, @isnumeric);
    addParameter(p, 'n_seeds', 3, @isnumeric);
    addParameter(p, 'results_dir', 'results/figures', @ischar);
    parse(p, varargin{:});
    opts = p.Results;

    if ~exist(opts.results_dir, 'dir'), mkdir(opts.results_dir); end

    % Force a row vector: the plotting code concatenates sizes horizontally.
    sizes = reshape(opts.sample_sizes, 1, []);
    n_sizes = numel(sizes);
    method_names = {'starG_SVD_Ridge', 'Augmented_MLP', 'Standard_MLP', 'Invariant_MLP'};
    n_methods = numel(method_names);
    R2_curves = zeros(n_methods, n_sizes, opts.n_seeds);

    fprintf('\n================================================================\n');
    fprintf(' Learning Curves: R^2 vs sample size\n');
    fprintf('================================================================\n');

    for si = 1:n_sizes
        n = sizes(si);
        fprintf('\n, n_molecules = %d , \n', n);

        % Named qm9 (not exp) so the builtin exp() is not shadowed.
        qm9 = QM9_experiment(opts.qm9_dir, opts.n_rotations, 'n_molecules', n);
        qm9 = qm9.load_data(n);
        qm9 = qm9.compute_rotated_features('n_feat', 48);
        y = qm9.properties_mat(:, opts.target_col);
        nf = qm9.n_feat_per_rot;
        G = qm9.G;
        % Number of slices along dim 3 of the feature tensor; the augmented
        % MLP stacks one copy of the targets per slice.
        n_rot = size(qm9.X_tensor, 3);

        for seed = 1:opts.n_seeds
            rng(seed*111, 'twister');
            idx = randperm(qm9.n_molecules);
            ntr = round(0.7*qm9.n_molecules);
            nva = round(0.15*qm9.n_molecules);
            tri = idx(1:ntr);
            vai = idx(ntr+1:ntr+nva);
            tei = idx(ntr+nva+1:end);
            Xtr = qm9.X_tensor(tri,:,:);
            Xva = qm9.X_tensor(vai,:,:);
            Xte = qm9.X_tensor(tei,:,:);
            ytr = y(tri); yva = y(vai); yte = y(tei);

            % 1. starG-SVD + Ridge
            [ftr, np] = extractStarGFeatures(Xtr, G, nf);
            fva = extractStarGFeatures(Xva, G, nf, np);
            fte = extractStarGFeatures(Xte, G, nf, np);
            pp = size(ftr, 2);
            R = eye(pp); R(1,1) = 0;   % first coefficient is left unpenalised
            lams = [0.001, 0.01, 0.1, 1, 10, 100, 1000];
            FtF = ftr' * ftr;          % loop-invariant normal-equation terms
            Fty = ftr' * ytr;
            be = Inf; bl = 0.01;
            for lam = lams
                wt = (FtF + lam*R) \ Fty;
                e = mean((yva - fva*wt).^2);
                if e < be, be = e; bl = lam; end
            end
            w = (FtF + bl*R) \ Fty;
            R2_curves(1, si, seed) = lc_r2_score(yte, fte * w);

            % 2. Augmented MLP: every (molecule, rotation) pair is a sample.
            % After permute+reshape the rows run molecule-fastest within
            % each rotation, so tiling the targets n_rot times lines up.
            Xa  = reshape(permute(Xtr, [1,3,2]), [], nf); ya  = repmat(ytr, n_rot, 1);
            Xav = reshape(permute(Xva, [1,3,2]), [], nf); yav = repmat(yva, n_rot, 1);
            mu  = mean(Xa, 1);
            sig = std(Xa, 0, 1) + 1e-8;
            [W, B] = train_mlp_simple((Xa-mu)./sig, ya, (Xav-mu)./sig, yav, [nf,64,32,1], 200, 0.003);
            yp2 = predict_mlp_simple((Xte(:,:,1)-mu)./sig, W, B);
            R2_curves(2, si, seed) = lc_r2_score(yte, yp2);

            % 3. Standard MLP: first rotation only
            Xs = Xtr(:,:,1); Xsv = Xva(:,:,1); Xst = Xte(:,:,1);
            mu3 = mean(Xs, 1);
            s3  = std(Xs, 0, 1) + 1e-8;
            [W3, B3] = train_mlp_simple((Xs-mu3)./s3, ytr, (Xsv-mu3)./s3, yva, [nf,64,32,1], 200, 0.003);
            yp3 = predict_mlp_simple((Xst-mu3)./s3, W3, B3);
            R2_curves(3, si, seed) = lc_r2_score(yte, yp3);

            % 4. Invariant MLP: mean and std of each feature across dim 3
            fi_tr = [mean(Xtr,3), std(Xtr,0,3)];
            fi_va = [mean(Xva,3), std(Xva,0,3)];
            fi_te = [mean(Xte,3), std(Xte,0,3)];
            mu4 = mean(fi_tr, 1);
            s4  = std(fi_tr, 0, 1) + 1e-8;
            [W4, B4] = train_mlp_simple((fi_tr-mu4)./s4, ytr, (fi_va-mu4)./s4, yva, ...
                [size(fi_tr,2),64,32,1], 200, 0.003);
            yp4 = predict_mlp_simple((fi_te-mu4)./s4, W4, B4);
            R2_curves(4, si, seed) = lc_r2_score(yte, yp4);
        end

        for m = 1:n_methods
            fprintf(' %s: R2 = %.3f +/- %.3f\n', method_names{m}, ...
                mean(R2_curves(m,si,:)), std(R2_curves(m,si,:)));
        end
    end

    % Save results
    save(fullfile(opts.results_dir, 'learning_curves.mat'), ...
        'R2_curves', 'sizes', 'method_names');

    % Plot
    fig = figure('Position', [100, 100, 700, 450], 'Color', 'w');
    colors = [0.17, 0.63, 0.37;   % starG green
              0.70, 0.70, 0.70;   % augmented gray
              0.55, 0.55, 0.55;   % standard gray
              0.40, 0.40, 0.40];  % invariant gray
    markers = {'o', 's', '^', 'd'};
    display_names = {'star_G-SVD + Ridge', 'Augmented MLP', 'Standard MLP', 'Invariant MLP'};
    for m = 1:n_methods
        R2_m = reshape(mean(R2_curves(m,:,:), 3), 1, []);
        R2_s = reshape(std(R2_curves(m,:,:), 0, 3), 1, []);
        % Shade the error band. Hidden from the legend so that only the
        % four named curves are listed.
        fill([sizes, fliplr(sizes)], ...
            [R2_m + R2_s, fliplr(R2_m - R2_s)], ...
            colors(m,:), 'FaceAlpha', 0.15, 'EdgeColor', 'none', ...
            'HandleVisibility', 'off');
        hold on;
        plot(sizes, R2_m, ['-', markers{m}], 'Color', colors(m,:), ...
            'MarkerFaceColor', colors(m,:), 'LineWidth', 2, 'MarkerSize', 8, ...
            'DisplayName', display_names{m});
    end
    yline(0, 'k--', 'LineWidth', 1, 'HandleVisibility', 'off');
    xlabel('Number of Molecules', 'FontSize', 13, 'FontWeight', 'bold');
    ylabel('Test R^2', 'FontSize', 13, 'FontWeight', 'bold');
    title('Learning Curves: R^2 vs Sample Size', 'FontSize', 14, 'FontWeight', 'bold');
    legend('Location', 'southeast', 'FontSize', 10);
    set(gca, 'FontSize', 11);
    grid on; box on;
    set(fig, 'InvertHardcopy', 'off');
    exportgraphics(fig, fullfile(opts.results_dir, 'fig3d_learning_curves.pdf'), 'ContentType', 'vector', 'BackgroundColor', 'w');
    saveas(fig, fullfile(opts.results_dir, 'fig3d_learning_curves.png'));
    fprintf('\nFigure saved to %s/fig3d_learning_curves.pdf/.png\n', opts.results_dir);

    % Print summary table
    fprintf('\n, Learning Curves Summary , \n');
    fprintf('%8s', 'n_mol');
    for m = 1:n_methods, fprintf(' %18s', method_names{m}); end
    fprintf('\n');
    for si = 1:n_sizes
        fprintf('%8d', sizes(si));
        for m = 1:n_methods
            fprintf(' %7.3f +/- %5.3f', mean(R2_curves(m,si,:)), std(R2_curves(m,si,:)));
        end
        fprintf('\n');
    end
end

function r2 = lc_r2_score(y_true, y_pred)
%LC_R2_SCORE  Coefficient of determination, 1 - SS_res / SS_tot.
%   The 1e-20 added to the denominator avoids division by zero when
%   y_true is constant.
    ss_res = sum((y_true - y_pred).^2);
    ss_tot = sum((y_true - mean(y_true)).^2) + 1e-20;
    r2 = 1 - ss_res / ss_tot;
end
%% Standalone MLP functions (no class dependency)
function [W, B] = train_mlp_simple(X, y, Xv, yv, layers, maxep, lr)
%TRAIN_MLP_SIMPLE  Train a ReLU MLP regressor with Adam and early stopping.
%
%   [W, B] = train_mlp_simple(X, y, Xv, yv, layers, maxep, lr)
%
%   Inputs
%     X, y    training inputs (one sample per row) and targets
%     Xv, yv  validation inputs and targets, used for early stopping;
%             may be empty, in which case early stopping is disabled
%     layers  layer widths, e.g. [n_in, 64, 32, 1]
%     maxep   maximum number of epochs
%     lr      Adam step size
%
%   Outputs
%     W, B    cell arrays of weight matrices and bias row vectors, one per
%             layer. With validation data these are the weights from the
%             epoch with the lowest validation MSE; without it, the weights
%             after the last epoch.
%
%   Hidden layers use ReLU; the last layer is linear. Training minimises
%   squared error in mini-batches of at most 256 rows and stops after 20
%   consecutive epochs without a validation improvement. Weight
%   initialisation and batch shuffling draw from the global random stream.

    % Targets are handled as column vectors so A{end}-yb never broadcasts.
    y = y(:); yv = yv(:);

    nl = length(layers)-1;
    W = cell(nl,1); B = cell(nl,1);
    for l = 1:nl
        fi = layers(l);
        W{l} = randn(fi, layers(l+1)) * sqrt(2/fi);   % scale by sqrt(2/fan_in)
        B{l} = zeros(1, layers(l+1));
    end

    % Adam first/second moment accumulators
    mW = cellfun(@(w) zeros(size(w)), W, 'Uni', 0); vW = mW;
    mB = cellfun(@(b) zeros(size(b)), B, 'Uni', 0); vB = mB;
    b1 = 0.9; b2 = 0.999; ea = 1e-8;

    bv = Inf; pa = 20; wa = 0; Wb = W; Bb = B;
    nn = size(X,1); bs = min(256, nn);
    has_val = ~isempty(Xv) && ~isempty(yv);
    t = 0;   % number of Adam updates so far (drives bias correction)

    for ep = 1:maxep
        pm = randperm(nn);
        for s = 1:bs:nn
            bi = pm(s:min(s+bs-1, nn)); Xb = X(bi,:); yb = y(bi);

            % Forward pass, keeping activations for backprop
            A = cell(nl+1,1); A{1} = Xb;
            for l = 1:nl
                Zl = A{l}*W{l} + B{l};
                if l < nl, A{l+1} = max(0, Zl); else, A{l+1} = Zl; end
            end

            % Bias correction must use the update count, not the epoch
            % index: the moment estimates advance once per mini-batch.
            t = t + 1;
            c1 = 1 - b1^t;
            c2 = 1 - b2^t;

            dZ = (A{nl+1} - yb) / size(Xb,1);
            for l = nl:-1:1
                gW = A{l}'*dZ; gB = sum(dZ,1);
                % Propagate through W{l} before it is overwritten below.
                if l > 1, dZ = (dZ*W{l}') .* (A{l} > 0); end
                mW{l} = b1*mW{l} + (1-b1)*gW; vW{l} = b2*vW{l} + (1-b2)*gW.^2;
                mB{l} = b1*mB{l} + (1-b1)*gB; vB{l} = b2*vB{l} + (1-b2)*gB.^2;
                W{l} = W{l} - lr*(mW{l}/c1) ./ (sqrt(vW{l}/c2) + ea);
                B{l} = B{l} - lr*(mB{l}/c1) ./ (sqrt(vB{l}/c2) + ea);
            end
        end

        % Without validation data the MSE would be NaN and never count as
        % an improvement, so skip early stopping entirely.
        if ~has_val, continue; end

        yp = Xv;
        for l = 1:nl
            yp = yp*W{l} + B{l};
            if l < nl, yp = max(0, yp); end
        end
        vm = mean((yv - yp).^2);
        if vm < bv
            bv = vm; wa = 0; Wb = W; Bb = B;
        else
            wa = wa + 1;
            if wa >= pa, break; end
        end
    end

    if has_val
        W = Wb; B = Bb;
    end
end
function yp = predict_mlp_simple(X, W, B)
%PREDICT_MLP_SIMPLE  Forward pass of an MLP stored as weight/bias cell arrays.
%   yp = predict_mlp_simple(X, W, B) applies, for each layer k in order,
%   the affine map  yp*W{k} + B{k}. Every layer except the last is followed
%   by a ReLU; the last layer's output is returned unchanged. If W is
%   empty, X is returned as is.
    n_layers = length(W);
    yp = X;
    k = 1;
    while k <= n_layers
        yp = yp * W{k} + B{k};
        is_hidden = (k ~= n_layers);
        if is_hidden
            yp = max(0, yp);   % ReLU on hidden layers only
        end
        k = k + 1;
    end
end