%% Path: tensor-group-sym/experiments/run_learning_curves.m
%% ========================================================================
%% run_learning_curves.m
%% Generate learning curves: R2 vs n_molecules for all methods
%% This produces the data for Figure 3D (QM9 scaling behavior)
%%
%% Usage:
%%   >> run_learning_curves('qm9_dir', './data/xyz/')
%%   >> run_learning_curves   % synthetic fallback
%%
%% LH & Claude 2026
%% ========================================================================

function run_learning_curves(varargin)
% RUN_LEARNING_CURVES  Test R^2 vs. number of molecules for four regressors.
%
%   For every entry of 'sample_sizes' this loads that many molecules via
%   QM9_experiment, computes rotated features, and for each of 'n_seeds'
%   seeded 70/15/15 train/val/test splits fits:
%     1. starG-SVD features + ridge regression (lambda chosen on val MSE)
%     2. an MLP trained on every rotated copy of each molecule (augmentation)
%     3. an MLP trained on the first rotation slice only
%     4. an MLP on per-feature mean and std taken across rotations
%   The three MLPs are evaluated on the test set's first rotation slice
%   (methods 2-3) or on the rotation-pooled statistics (method 4).
%
%   Outputs (under 'results_dir'):
%     learning_curves.mat        - R2_curves(method, size, seed), sizes,
%                                  method_names
%     fig3d_learning_curves.pdf  - mean +/- std test R^2 vs. sample size
%     fig3d_learning_curves.png
%   A summary table is also printed to the command window.
%
%   Name-value options (defaults in parentheses):
%     'qm9_dir'      - directory passed to QM9_experiment ('')
%     'sample_sizes' - molecule counts to evaluate ([100 200 500 750 1000])
%     'n_rotations'  - rotation count passed to QM9_experiment (12)
%     'target_col'   - column of properties_mat used as target (8)
%     'n_seeds'      - random splits per sample size (3)
%     'results_dir'  - output folder, created if missing ('results/figures')
%
%   Side effect: overrides several groot color defaults (white background,
%   black text) for the rest of the MATLAB session, and adds the project's
%   core/ and experiments/ folders to the path.

    % FORCE WHITE THEME (groot defaults persist after this function returns)
    set(groot, 'defaultFigureColor', 'w');
    set(groot, 'defaultAxesColor', 'w');
    set(groot, 'defaultAxesXColor', 'k');
    set(groot, 'defaultAxesYColor', 'k');
    set(groot, 'defaultTextColor', 'k');
    set(groot, 'defaultAxesGridColor', [0.8 0.8 0.8]);

    % This file is expected to sit one folder below the project root.
    thisDir = fileparts(mfilename('fullpath'));
    rootDir = fileparts(thisDir);
    addpath(fullfile(rootDir, 'core'));
    addpath(fullfile(rootDir, 'experiments'));

    p = inputParser;
    addParameter(p, 'qm9_dir', '', @ischar);
    addParameter(p, 'sample_sizes', [100, 200, 500, 750, 1000], @isnumeric);
    addParameter(p, 'n_rotations', 12, @isnumeric);
    addParameter(p, 'target_col', 8, @isnumeric);
    addParameter(p, 'n_seeds', 3, @isnumeric);
    addParameter(p, 'results_dir', 'results/figures', @ischar);
    parse(p, varargin{:});
    opts = p.Results;

    if ~exist(opts.results_dir, 'dir'), mkdir(opts.results_dir); end

    % Force a row vector: the error-band fill below concatenates
    % [sizes, fliplr(sizes)] horizontally.
    sizes = reshape(opts.sample_sizes, 1, []);
    n_sizes = numel(sizes);
    method_names = {'starG_SVD_Ridge', 'Augmented_MLP', 'Standard_MLP', 'Invariant_MLP'};
    n_methods = numel(method_names);
    ridge_lams = [0.001, 0.01, 0.1, 1, 10, 100, 1000];

    % R2_curves(method, size index, seed) = test R^2
    R2_curves = zeros(n_methods, n_sizes, opts.n_seeds);

    fprintf('\n================================================================\n');
    fprintf('  Learning Curves: R^2 vs sample size\n');
    fprintf('================================================================\n');

    for si = 1:n_sizes
        n = sizes(si);
        fprintf('\n,  n_molecules = %d , \n', n);

        % Named 'qm' rather than 'exp' so the built-in exp() is not shadowed.
        qm = QM9_experiment(opts.qm9_dir, opts.n_rotations, 'n_molecules', n);
        qm = qm.load_data(n);
        qm = qm.compute_rotated_features('n_feat', 48);

        y = qm.properties_mat(:, opts.target_col);
        nf = qm.n_feat_per_rot;
        G = qm.G;
        n_mol = qm.n_molecules;

        for seed = 1:opts.n_seeds
            % Reproducible 70/15/15 train/val/test split for this seed.
            rng(seed*111, 'twister');
            idx = randperm(n_mol);
            ntr = round(0.7*n_mol);
            nva = round(0.15*n_mol);
            tri = idx(1:ntr);
            vai = idx(ntr+1:ntr+nva);
            tei = idx(ntr+nva+1:end);

            % X_tensor is indexed (molecule, feature, rotation) throughout.
            Xtr = qm.X_tensor(tri,:,:);
            Xva = qm.X_tensor(vai,:,:);
            Xte = qm.X_tensor(tei,:,:);
            ytr = y(tri); yva = y(vai); yte = y(tei);
            % Rotation count taken from the data itself, so the augmented
            % labels below always line up with the reshaped features.
            n_rot = size(Xtr, 3);

            % 1. starG-SVD + Ridge
            [ftr, np] = extractStarGFeatures(Xtr, G, nf);
            fva = extractStarGFeatures(Xva, G, nf, np);
            fte = extractStarGFeatures(Xte, G, nf, np);
            yp1 = ridge_val_predict(ftr, ytr, fva, yva, fte, ridge_lams);
            R2_curves(1, si, seed) = r2_score(yte, yp1);

            % 2. Augmented MLP: each rotated copy is a training sample.
            % After permute+reshape, row order is molecule-fastest then
            % rotation, so stacking y n_rot times keeps labels aligned.
            Xa  = reshape(permute(Xtr, [1, 3, 2]), [], nf);
            Xav = reshape(permute(Xva, [1, 3, 2]), [], nf);
            yp2 = standardized_mlp(Xa, repmat(ytr, n_rot, 1), ...
                                   Xav, repmat(yva, n_rot, 1), Xte(:,:,1));
            R2_curves(2, si, seed) = r2_score(yte, yp2);

            % 3. Standard MLP on the first rotation slice only
            yp3 = standardized_mlp(Xtr(:,:,1), ytr, Xva(:,:,1), yva, Xte(:,:,1));
            R2_curves(3, si, seed) = r2_score(yte, yp3);

            % 4. Invariant MLP on per-feature mean/std across rotations
            fi_tr = [mean(Xtr,3), std(Xtr,0,3)];
            fi_va = [mean(Xva,3), std(Xva,0,3)];
            fi_te = [mean(Xte,3), std(Xte,0,3)];
            yp4 = standardized_mlp(fi_tr, ytr, fi_va, yva, fi_te);
            R2_curves(4, si, seed) = r2_score(yte, yp4);
        end

        for m = 1:n_methods
            fprintf('  %s: R2 = %.3f +/- %.3f\n', method_names{m}, ...
                mean(R2_curves(m,si,:)), std(R2_curves(m,si,:)));
        end
    end

    % Save results
    save(fullfile(opts.results_dir, 'learning_curves.mat'), ...
        'R2_curves', 'sizes', 'method_names');

    % Plot
    fig = figure('Position', [100, 100, 700, 450], 'Color', 'w');

    colors = [0.17, 0.63, 0.37;   % starG green
              0.70, 0.70, 0.70;   % augmented gray
              0.55, 0.55, 0.55;   % standard gray
              0.40, 0.40, 0.40];  % invariant gray
    markers = {'o', 's', '^', 'd'};
    display_names = {'star_G-SVD + Ridge', 'Augmented MLP', 'Standard MLP', 'Invariant MLP'};

    for m = 1:n_methods
        R2_m = reshape(mean(R2_curves(m,:,:), 3), 1, []);
        R2_s = reshape(std(R2_curves(m,:,:), 0, 3), 1, []);

        % Shaded +/-1 std band across seeds; hidden from the legend so the
        % legend lists only the four method lines.
        fill([sizes, fliplr(sizes)], ...
             [R2_m + R2_s, fliplr(R2_m - R2_s)], ...
             colors(m,:), 'FaceAlpha', 0.15, 'EdgeColor', 'none', ...
             'HandleVisibility', 'off');
        hold on;
        plot(sizes, R2_m, ['-', markers{m}], 'Color', colors(m,:), ...
            'MarkerFaceColor', colors(m,:), 'LineWidth', 2, 'MarkerSize', 8, ...
            'DisplayName', display_names{m});
    end

    yline(0, 'k--', 'LineWidth', 1, 'HandleVisibility', 'off');
    xlabel('Number of Molecules', 'FontSize', 13, 'FontWeight', 'bold');
    ylabel('Test R^2', 'FontSize', 13, 'FontWeight', 'bold');
    title('Learning Curves: R^2 vs Sample Size', 'FontSize', 14, 'FontWeight', 'bold');
    legend('Location', 'southeast', 'FontSize', 10);
    set(gca, 'FontSize', 11);
    grid on; box on;

    set(fig, 'InvertHardcopy', 'off');
    exportgraphics(fig, fullfile(opts.results_dir, 'fig3d_learning_curves.pdf'), 'ContentType', 'vector', 'BackgroundColor', 'w');
    saveas(fig, fullfile(opts.results_dir, 'fig3d_learning_curves.png'));
    fprintf('\nFigure saved to %s/fig3d_learning_curves.pdf/.png\n', opts.results_dir);

    % Print summary table
    fprintf('\n,  Learning Curves Summary , \n');
    fprintf('%8s', 'n_mol');
    for m = 1:n_methods, fprintf('  %18s', method_names{m}); end
    fprintf('\n');
    for si = 1:n_sizes
        fprintf('%8d', sizes(si));
        for m = 1:n_methods
            fprintf('  %7.3f +/- %5.3f', mean(R2_curves(m,si,:)), std(R2_curves(m,si,:)));
        end
        fprintf('\n');
    end
end


function yp = ridge_val_predict(ftr, ytr, fva, yva, fte, lams)
% Ridge regression whose lambda is chosen by validation MSE over LAMS;
% returns predictions for FTE. Column 1 of the design matrix is left
% unpenalized (presumably a bias column from extractStarGFeatures - TODO
% confirm). Falls back to lambda = 0.01 if no candidate improves on Inf.
    R = eye(size(ftr, 2));
    R(1,1) = 0;
    gram = ftr' * ftr;   % loop-invariant normal-equation terms
    rhs = ftr' * ytr;
    best_err = Inf; best_lam = 0.01;
    for lam = lams
        w = (gram + lam*R) \ rhs;
        err = mean((yva - fva*w).^2);
        if err < best_err, best_err = err; best_lam = lam; end
    end
    yp = fte * ((gram + best_lam*R) \ rhs);
end


function yp = standardized_mlp(Xtr, ytr, Xva, yva, Xte)
% Z-score inputs with training-set statistics (column-wise; the 1e-8
% guards constant columns), train a [d,64,32,1] MLP with early stopping
% on (Xva, yva), and return predictions for Xte.
    mu = mean(Xtr, 1);
    sig = std(Xtr, 0, 1) + 1e-8;
    [W, B] = train_mlp_simple((Xtr-mu)./sig, ytr, (Xva-mu)./sig, yva, ...
        [size(Xtr,2), 64, 32, 1], 200, 0.003);
    yp = predict_mlp_simple((Xte-mu)./sig, W, B);
end


function r2 = r2_score(y_true, y_pred)
% Coefficient of determination; the 1e-20 avoids division by zero when
% y_true is constant.
    r2 = 1 - sum((y_true - y_pred).^2) / (sum((y_true - mean(y_true)).^2) + 1e-20);
end


%% Standalone MLP functions (no class dependency)
function [W, B] = train_mlp_simple(X, y, Xv, yv, layers, maxep, lr, patience, batch_size)
% TRAIN_MLP_SIMPLE  Train a fully-connected ReLU regression MLP with Adam.
%
%   [W, B] = train_mlp_simple(X, y, Xv, yv, layers, maxep, lr)
%   [W, B] = train_mlp_simple(..., patience, batch_size)
%
%   X, y       - training inputs (one sample per row) and targets
%   Xv, yv     - validation inputs/targets used for early stopping
%   layers     - widths [n_in, h1, ..., n_out]; hidden layers use ReLU,
%                the output layer is linear
%   maxep      - maximum number of epochs
%   lr         - Adam step size
%   patience   - epochs without a validation-MSE improvement before
%                stopping (default 20)
%   batch_size - mini-batch size, capped at size(X,1) (default 256)
%
%   Returns cell arrays W{l} (layers(l) x layers(l+1)) and B{l}
%   (1 x layers(l+1)) holding the weights with the lowest validation MSE.
%   If the validation set is empty, early stopping is disabled and the
%   weights after the final epoch are returned.
%
%   The loss is 0.5*mean squared error; weights use He initialisation and
%   biases start at zero.
    if nargin < 8 || isempty(patience), patience = 20; end
    if nargin < 9 || isempty(batch_size), batch_size = 256; end

    % Column targets so they never broadcast against column predictions.
    y = y(:);
    yv = yv(:);

    nl = numel(layers) - 1;
    W = cell(nl, 1); B = cell(nl, 1);
    for l = 1:nl
        fan_in = layers(l);
        W{l} = randn(fan_in, layers(l+1)) * sqrt(2/fan_in);   % He init
        B{l} = zeros(1, layers(l+1));
    end

    % Adam first/second moment estimates.
    mW = cellfun(@(w) zeros(size(w)), W, 'UniformOutput', false); vW = mW;
    mB = cellfun(@(b) zeros(size(b)), B, 'UniformOutput', false); vB = mB;
    beta1 = 0.9; beta2 = 0.999; eps_adam = 1e-8;
    % Adam bias correction must use the number of updates taken so far,
    % not the epoch index (they differ whenever an epoch has >1 batch).
    t = 0;

    use_val = ~isempty(yv);
    best_val = Inf; stall = 0; Wbest = W; Bbest = B;
    nn = size(X, 1);
    bs = min(batch_size, nn);
    for ep = 1:maxep
        perm = randperm(nn);
        for s = 1:bs:nn
            bi = perm(s:min(s+bs-1, nn));
            Xb = X(bi, :); yb = y(bi);

            % Forward pass, keeping every layer's activation for backprop.
            A = cell(nl+1, 1); A{1} = Xb;
            for l = 1:nl
                Z = A{l}*W{l} + B{l};
                if l < nl, A{l+1} = max(0, Z); else, A{l+1} = Z; end
            end

            % Gradient of 0.5*mean((yhat - y).^2) w.r.t. the output.
            dZ = (A{nl+1} - yb) / numel(bi);
            t = t + 1;
            c1 = 1 - beta1^t;
            c2 = 1 - beta2^t;
            for l = nl:-1:1
                gW = A{l}'*dZ; gB = sum(dZ, 1);
                % Back-propagate through W{l} *before* it is updated.
                if l > 1, dZ = (dZ*W{l}') .* (A{l} > 0); end
                mW{l} = beta1*mW{l} + (1-beta1)*gW; vW{l} = beta2*vW{l} + (1-beta2)*gW.^2;
                mB{l} = beta1*mB{l} + (1-beta1)*gB; vB{l} = beta2*vB{l} + (1-beta2)*gB.^2;
                W{l} = W{l} - lr*(mW{l}/c1) ./ (sqrt(vW{l}/c2) + eps_adam);
                B{l} = B{l} - lr*(mB{l}/c1) ./ (sqrt(vB{l}/c2) + eps_adam);
            end
        end

        if ~use_val
            continue;   % nothing to early-stop on
        end

        % Validation MSE with the current weights.
        yp = Xv;
        for l = 1:nl
            yp = yp*W{l} + B{l};
            if l < nl, yp = max(0, yp); end
        end
        vm = mean((yv - yp).^2);
        if vm < best_val
            best_val = vm; stall = 0; Wbest = W; Bbest = B;
        else
            stall = stall + 1;
            if stall >= patience, break; end
        end
    end

    if use_val
        W = Wbest; B = Bbest;
    end
end

function yp = predict_mlp_simple(X, W, B)
% PREDICT_MLP_SIMPLE  Forward pass of the MLP from train_mlp_simple.
%   Each layer k computes h*W{k} + B{k}; every layer except the last is
%   followed by a ReLU, and the last layer stays linear. With empty W the
%   input is returned unchanged.
    nLayers = length(W);
    h = X;
    for k = 1:nLayers-1
        h = max(0, h*W{k} + B{k});   % hidden layer: affine + ReLU
    end
    if nLayers > 0
        h = h*W{nLayers} + B{nLayers};   % output layer: affine only
    end
    yp = h;
end