%% Source: tensor-group-sym / experiments / generate_supplementary_figures.m
%% ========================================================================
%% generate_supplementary_figures.m
%% Scatter plots, ablation cascade, paradigm comparison
%%
%% FORCES WHITE BACKGROUND.
%% Usage: >> generate_supplementary_figures
%% LH & Claude 2026
%% ========================================================================

function generate_supplementary_figures()
% GENERATE_SUPPLEMENTARY_FIGURES  Render every supplementary figure to disk.
%
%   Switches the graphics-root defaults to a white theme, puts the
%   project's 'core' and 'experiments' folders on the path, makes sure the
%   output folder results/figures exists (relative to the current folder),
%   and then renders the scatter, ablation and paradigm figures in order.
%
%   Takes no arguments and returns nothing; progress is printed to the
%   command window.

    % White-theme defaults on the graphics root (overrides MATLAB dark mode).
    set(groot, ...
        'defaultFigureColor',   'w', ...
        'defaultAxesColor',     'w', ...
        'defaultAxesXColor',    'k', ...
        'defaultAxesYColor',    'k', ...
        'defaultTextColor',     'k', ...
        'defaultAxesGridColor', [0.8 0.8 0.8]);

    % Project folders are located from this file's own position: the parent
    % of the folder containing this file is treated as the project root.
    hereDir = fileparts(mfilename('fullpath'));
    projectDir = fileparts(hereDir);
    for sub = {'core', 'experiments'}
        addpath(fullfile(projectDir, sub{1}));
    end

    % Output folder; created on first use.
    figDir = fullfile('results', 'figures');
    if exist(figDir, 'dir') ~= 7
        mkdir(figDir);
    end

    % Shared colour palette handed to every figure routine.
    C = struct( ...
        'starG',  [0.17, 0.63, 0.37], ...
        'gray',   [0.55, 0.55, 0.55], ...
        'accent', [0.15, 0.35, 0.70]);

    fprintf('Generating supplementary figures...\n');
    fig_scatter(figDir, C);
    fig_ablation(figDir, C);
    fig_paradigm(figDir, C);
    fprintf('Done.\n');
end

function force_white(fig)
% FORCE_WHITE  Apply the white theme to an existing figure and its axes.
%
%   fig - figure whose background and axes colours are overwritten
%
%   Re-applies the white graphics-root defaults, paints the figure
%   background white with InvertHardcopy switched off, and then recolours
%   every axes found in the figure (background, x/y rulers, grid, title).

    % White-theme defaults on the graphics root (overrides MATLAB dark mode).
    set(groot, ...
        'defaultFigureColor',   'w', ...
        'defaultAxesColor',     'w', ...
        'defaultAxesXColor',    'k', ...
        'defaultAxesYColor',    'k', ...
        'defaultTextColor',     'k', ...
        'defaultAxesGridColor', [0.8 0.8 0.8]);

    set(fig, 'InvertHardcopy', 'off', 'Color', 'w');

    % findall is used so that axes with hidden handles are included too.
    allAxes = findall(fig, 'Type', 'axes');
    for k = 1:numel(allAxes)
        ax = allAxes(k);
        ax.Color     = 'w';
        ax.XColor    = 'k';
        ax.YColor    = 'k';
        ax.GridColor = [0.8,0.8,0.8];
        ax.Title.Color = [0 0 0];
    end
end


%% ========================================================================
%% Predicted vs True scatter
%% ========================================================================
function fig_scatter(figDir, C)
% FIG_SCATTER  Predicted-vs-true scatter panels for two regressors.
%
%   figDir - folder that receives fig_scatter.pdf and fig_scatter.png
%   C      - colour struct; the fields starG and gray are used here
%
%   Builds a synthetic data set of random point clouds with a scalar
%   target, fits (a) a ridge regression on the features returned by
%   extractStarGFeatures and (b) a small ReLU MLP trained by train_simple
%   on the first slice of the feature tensor, then draws the two test-set
%   scatter plots side by side and exports them.

    % White-theme defaults on the graphics root (overrides MATLAB dark mode).
    set(groot, ...
        'defaultFigureColor',   'w', ...
        'defaultAxesColor',     'w', ...
        'defaultAxesXColor',    'k', ...
        'defaultAxesYColor',    'k', ...
        'defaultTextColor',     'k', ...
        'defaultAxesGridColor', [0.8 0.8 0.8]);

    % Fixed seed; the order of the random draws below must not change.
    rng(42);
    nRot = 12;
    G = StarGAlgebra('cyclic', nRot);
    nMol = 500;
    atomTypes = [1,6,7,8,9];
    theta = (0:nRot-1)*2*pi/nRot;
    % nRot unit directions in the x-y plane, one per column.
    planeDirs = [cos(theta); sin(theta); zeros(1,nRot)];

    % ---- synthetic molecules and target --------------------------------
    y = zeros(nMol,1);
    coordsAll = cell(nMol,1);
    ZAll = cell(nMol,1);
    for m = 1:nMol
        nAtoms = randi([4,12]);
        pos = randn(nAtoms,3)*1.5;
        pos = pos-mean(pos,1);                 % centre the point cloud
        Z = atomTypes(randi(5,nAtoms,1))';     % column of atom types
        coordsAll{m} = pos;
        ZAll{m} = Z;

        D = squareform(pdist(pos));
        pairDist = D(D>0);
        zWeights = Z/(sum(Z)+1e-10);
        % Power spectrum of the Z-weighted mean position projected onto
        % the nRot in-plane directions.
        spec = abs(fft(zWeights'*pos*planeDirs)).^2;
        y(m) = mean(pairDist)+0.003*spec(2)+0.002*spec(3)+sum(Z.^2)/500;
    end

    % ---- feature tensor via the project's experiment object ------------
    qm9 = QM9_experiment('dummy',nRot);
    qm9.coords = coordsAll;
    qm9.atomic_numbers = ZAll;
    qm9.n_molecules = nMol;
    qm9.properties_mat = zeros(nMol,15);
    qm9.properties_mat(:,8) = y;
    qm9 = qm9.compute_rotated_features('n_feat',48);
    X = qm9.X_tensor;

    % ---- 70 / 15 / 15 random split --------------------------------------
    order = randperm(nMol);
    nTrain = round(0.7*nMol);
    nVal = round(0.15*nMol);
    tri = order(1:nTrain);
    vai = order(nTrain+1:nTrain+nVal);
    tei = order(nTrain+nVal+1:end);
    ytr = y(tri);
    yva = y(vai);
    yte = y(tei);

    % ---- starG ridge regression -----------------------------------------
    [ftr,normPar] = extractStarGFeatures(X(tri,:,:),G,48);
    fva = extractStarGFeatures(X(vai,:,:),G,48,normPar);
    fte = extractStarGFeatures(X(tei,:,:),G,48,normPar);

    % Ridge penalty matrix; the first feature column is left unpenalised.
    nFeat = size(ftr,2);
    R = eye(nFeat);
    R(1,1) = 0;

    % Pick the penalty with the lowest validation MSE (first one wins ties).
    bestErr = Inf;
    bestLam = 0.01;
    for lam = [0.001,0.01,0.1,1,10,100]
        wTry = (ftr'*ftr+lam*R)\(ftr'*ytr);
        valErr = mean((yva-fva*wTry).^2);
        if valErr < bestErr
            bestErr = valErr;
            bestLam = lam;
        end
    end
    wRidge = (ftr'*ftr+bestLam*R)\(ftr'*ytr);
    yp_starG = fte*wRidge;

    % ---- standard MLP on the first tensor slice -------------------------
    Xs = squeeze(X(:,:,1));
    mu = mean(Xs(tri,:));
    sig = std(Xs(tri,:))+1e-8;
    standardize = @(M) (M-mu)./sig;   % training-set statistics everywhere
    [W,B] = train_simple(standardize(Xs(tri,:)), ytr, standardize(Xs(vai,:)), yva, [48,64,32,1], 200, 0.003);

    % Forward pass on the test set: ReLU on every layer except the last.
    nLayers = length(W);
    yp_mlp = standardize(Xs(tei,:));
    for l = 1:nLayers
        yp_mlp = yp_mlp*W{l}+B{l};
        if l < nLayers
            yp_mlp = max(0,yp_mlp);
        end
    end

    % ---- plot -----------------------------------------------------------
    fig = figure('Position',[100,100,900,400],'Color','w');

    lims = [min(yte)*0.95, max(yte)*1.05];
    rSquared = @(yhat) 1-sum((yte-yhat).^2)/(sum((yte-mean(yte)).^2)+1e-20);

    % One row per panel: predictions, marker colour, title format.
    panels = { yp_starG, C.starG, 'a   star_G-SVD + Ridge (R^2 = %.3f)'; ...
               yp_mlp,   C.gray,  'b   Standard MLP (R^2 = %.3f)' };
    for p = 1:size(panels,1)
        [yhat, markerColor, titleFmt] = panels{p,:};

        subplot(1,2,p);
        scatter(yte, yhat, 25, markerColor, 'filled', 'MarkerFaceAlpha',0.5);
        hold on;
        plot(lims, lims, 'k--', 'LineWidth',1.5);   % identity line
        xlabel('True y','FontSize',12,'FontWeight','bold');
        ylabel('Predicted y','FontSize',12,'FontWeight','bold');
        title(sprintf(titleFmt,rSquared(yhat)),'FontSize',11,'FontWeight','bold');
        xlim(lims);
        ylim(lims);
        axis square;
        grid on;
        box on;
    end

    sgtitle('Predicted vs True','FontSize',14,'FontWeight','bold','Color','k');
    force_white(fig);
    exportgraphics(fig, fullfile(figDir,'fig_scatter.pdf'),'ContentType','vector','BackgroundColor','w');
    saveas(fig, fullfile(figDir,'fig_scatter.png'));
    fprintf('  fig_scatter saved\n');
end


%% ========================================================================
%% Ablation Cascade
%% ========================================================================
function fig_ablation(figDir, C)
% FIG_ABLATION  Horizontal bar chart of test R^2 for five symmetry settings.
%
%   figDir - folder that receives fig_ablation.pdf and fig_ablation.png
%   C      - colour struct; the fields starG and gray are used here
%
%   The R^2 values are fixed numbers stored in this function; nothing is
%   recomputed here.

    % White-theme defaults on the graphics root (overrides MATLAB dark mode).
    set(groot, ...
        'defaultFigureColor',   'w', ...
        'defaultAxesColor',     'w', ...
        'defaultAxesXColor',    'k', ...
        'defaultAxesYColor',    'k', ...
        'defaultTextColor',     'k', ...
        'defaultAxesGridColor', [0.8 0.8 0.8]);

    fig = figure('Position',[100,100,650,380],'Color','w');

    labels = {'G_1 x G_2 (product)', 'Z_{24} (wrong cyclic)', ...
              'G_2 only (Z_4)', 'G_1 only (Z_6)', 'No symmetry (MLP)'};
    r2 = [1.000, 0.986, 0.229, 0.155, 0.114];
    colors = [C.starG; [0.85,0.33,0.1]; [0.9,0.45,0.1]; [0.95,0.61,0.15]; C.gray];

    % Values, labels and colours are all reversed together, so the first
    % entry lands at y = 5 and the last at y = 1.
    r2Rev = flip(r2);
    labelsRev = flip(labels);
    nBars = numel(r2Rev);

    b = barh(r2Rev,'FaceColor','flat','EdgeColor','none','BarWidth',0.6);
    b.CData = colors(nBars:-1:1,:);

    set(gca,'YTickLabel',labelsRev,'FontSize',10);
    xlabel('Test R^2','FontSize',12,'FontWeight','bold');
    title('Ablation: Removing Symmetry Components','FontSize',13,'FontWeight','bold');
    xlim([0,1.1]);
    hold on;

    % Numeric value printed just right of each bar end.
    for k = 1:nBars
        barValue = r2Rev(k);
        text(barValue+0.015, k, sprintf('%.3f',barValue), ...
            'FontSize',10,'VerticalAlignment','middle','FontWeight','bold');
    end

    grid on;
    box on;
    force_white(fig);
    exportgraphics(fig, fullfile(figDir,'fig_ablation.pdf'),'ContentType','vector','BackgroundColor','w');
    saveas(fig, fullfile(figDir,'fig_ablation.png'));
    fprintf('  fig_ablation saved\n');
end


%% ========================================================================
%% Paradigm Comparison (ENN vs star_G)
%% ========================================================================
function fig_paradigm(figDir, C)
% FIG_PARADIGM  Two-column schematic contrasting the ENN and Star-G paradigms.
%
%   figDir - folder that receives fig_paradigm.pdf and fig_paradigm.png
%   C      - colour struct; the field starG is used here
%
%   Everything is drawn on one invisible axes that covers the whole
%   figure. Boxes, text and the divider line are placed in axes data
%   coordinates, while the arrows are figure annotations placed in
%   normalized figure coordinates.

    % FORCE WHITE THEME (overrides MATLAB dark mode)
    set(groot, 'defaultFigureColor', 'w');
    set(groot, 'defaultAxesColor', 'w');
    set(groot, 'defaultAxesXColor', 'k');
    set(groot, 'defaultAxesYColor', 'k');
    set(groot, 'defaultTextColor', 'k');
    set(groot, 'defaultAxesGridColor', [0.8 0.8 0.8]);

    fig = figure('Position',[100,100,1000,480],'Color','w');

    % Full-figure canvas. The limits are pinned to [0,1] so that the data
    % coordinates used by text/rectangle/plot coincide with the normalized
    % figure coordinates used by annotation(). Under automatic limits the
    % axes is free to rescale to the drawn data, which would shift the
    % boxes and labels relative to the arrows.
    ax = axes('Parent',fig,'Position',[0,0,1,1],'XLim',[0,1],'YLim',[0,1]);
    axis(ax,'off');
    hold(ax,'on');

    % Small drawing helpers bound to this figure/axes.
    drawBox = @(pos, faceCol, edgeCol, lineW) rectangle('Parent',ax, ...
        'Position',pos,'Curvature',0.15, ...
        'FaceColor',faceCol,'EdgeColor',edgeCol,'LineWidth',lineW);
    drawArrow = @(x, yFromTo, col, lineW) annotation(fig,'arrow', ...
        [x,x],yFromTo,'Color',col,'LineWidth',lineW);

    % Title
    text(0.50, 0.96, 'Paradigm Comparison: Architecture vs Algebra',...
        'Parent',ax,'FontSize',16,'FontWeight','bold','HorizontalAlignment','center');

    % Divider line between the two columns
    plot(ax, [0.50, 0.50],[0.05, 0.90],'Color',[0.8,0.8,0.8],'LineWidth',1.5);

    % === LEFT: ENN ===
    enn_red = [0.80, 0.25, 0.20];
    enn_bg = [1.0, 0.92, 0.90];
    enn_bg2 = [1.0, 0.88, 0.82];
    arrow_gray = [0.5,0.5,0.5];

    text(0.25, 0.90, 'ENN Paradigm','Parent',ax,'FontSize',14,'FontWeight','bold',...
        'HorizontalAlignment','center','Color',enn_red);

    % Box 1: Rotation
    drawBox([0.04, 0.68, 0.42, 0.14], enn_bg, enn_red, 1.5);
    text(0.25, 0.75, {'Rotation symmetry:','Design SE(3)-equivariant layers'},...
        'Parent',ax,'FontSize',10,'HorizontalAlignment','center','Color',[0.2,0.2,0.2]);

    drawArrow(0.25, [0.68,0.64], arrow_gray, 1);

    % Box 2: Permutation
    drawBox([0.04, 0.46, 0.42, 0.14], enn_bg, enn_red, 1.5);
    text(0.25, 0.53, {'Permutation symmetry:','Design set-equivariant layers'},...
        'Parent',ax,'FontSize',10,'HorizontalAlignment','center','Color',[0.2,0.2,0.2]);

    drawArrow(0.25, [0.46,0.42], arrow_gray, 1);

    % Box 3: Both (emphasized with a darker fill and heavier edge)
    drawBox([0.04, 0.22, 0.42, 0.16], enn_bg2, [0.85,0.30,0.10], 2.5);
    text(0.25, 0.30, {'Both symmetries:','Redesign architecture from scratch ???'},...
        'Parent',ax,'FontSize',10,'HorizontalAlignment','center','FontWeight','bold','Color',[0.3,0.1,0.05]);

    text(0.25, 0.10, 'Each symmetry = new architecture',...
        'Parent',ax,'FontSize',11,'HorizontalAlignment','center','FontAngle','italic','Color',[0.5,0.2,0.2]);

    % === RIGHT: star_G ===
    sg = C.starG;
    sg_bg = [0.88, 0.96, 0.90];

    text(0.75, 0.90, 'Star-G Paradigm',...
        'Parent',ax,'FontSize',14,'FontWeight','bold','HorizontalAlignment','center','Color',sg);

    % Main algebra box
    drawBox([0.56, 0.62, 0.38, 0.18], sg_bg, sg, 2.5);
    text(0.75, 0.71, {'Star-G Algebra','(SVD, Ridge, features)'},...
        'Parent',ax,'FontSize',12,'HorizontalAlignment','center','FontWeight','bold','Color',sg*0.7);

    % Arrow down
    drawArrow(0.75, [0.62,0.57], sg, 1.5);

    % Three input boxes, one per group choice
    inputs = {'Rotation: G = Z_{12}', ...
              'Permutation: G = S_n', ...
              'Both: G = Z_{12} x S_n'};
    ypos = [0.48, 0.36, 0.22];
    for k = 1:numel(inputs)
        drawBox([0.56, ypos(k)-0.01, 0.38, 0.10], [0.92,0.97,0.93], sg*0.8, 1.2);
        text(0.75, ypos(k)+0.04, inputs{k},...
            'Parent',ax,'FontSize',10,'HorizontalAlignment','center','Color',[0.15,0.15,0.15]);
    end

    text(0.75, 0.10, 'Same algebra, just specify the group',...
        'Parent',ax,'FontSize',11,'HorizontalAlignment','center','FontAngle','italic','Color',sg*0.7);

    force_white(fig);
    exportgraphics(fig, fullfile(figDir,'fig_paradigm.pdf'),'ContentType','vector','BackgroundColor','w');
    saveas(fig, fullfile(figDir,'fig_paradigm.png'));
    fprintf('  fig_paradigm saved\n');
end


%% Standalone MLP
function [W,B] = train_simple(X,y,Xv,yv,layers,maxep,lr)
% TRAIN_SIMPLE  Train a fully connected ReLU network with mini-batch Adam.
%
%   X, y    - training inputs (one sample per row) and targets
%   Xv, yv  - validation inputs and targets, used only for early stopping
%   layers  - layer widths, input first and output last, e.g. [48,64,32,1];
%             X must have layers(1) columns
%   maxep   - maximum number of epochs
%   lr      - Adam step size
%
%   W, B    - cell arrays (one entry per layer) with the weights and row
%             bias vectors from the epoch with the lowest validation MSE
%
%   The loss is squared error; hidden layers use ReLU and the last layer
%   is linear. Training stops early once the validation MSE has failed to
%   improve for 20 consecutive epochs. Random numbers come from the global
%   stream, so seed it in the caller for reproducible results.
%
%   This routine is purely numerical and does not touch graphics state.

    nl = length(layers)-1;

    % He-style initialisation: N(0, 2/fan_in) weights, zero biases.
    W = cell(nl,1);
    B = cell(nl,1);
    for l = 1:nl
        fanIn = layers(l);
        W{l} = randn(fanIn,layers(l+1))*sqrt(2/fanIn);
        B{l} = zeros(1,layers(l+1));
    end

    % Adam first/second moment accumulators, same shapes as the parameters.
    mW = cellfun(@(w) zeros(size(w)), W, 'UniformOutput', false);  vW = mW;
    mB = cellfun(@(b) zeros(size(b)), B, 'UniformOutput', false);  vB = mB;
    b1 = 0.9;  b2 = 0.999;  ea = 1e-8;

    % Early-stopping state.
    pa = 20;          % patience in epochs
    bv = Inf;         % best validation MSE so far
    wa = 0;           % epochs since the last improvement
    Wb = W;  Bb = B;  % parameters at the best epoch

    nn = size(X,1);
    bs = min(256,nn);
    t = 0;            % number of Adam updates performed so far

    for ep = 1:maxep
        pm = randperm(nn);
        for s = 1:bs:nn
            bi = pm(s:min(s+bs-1,nn));
            Xb = X(bi,:);
            yb = y(bi);

            % Forward pass, keeping every activation for backprop.
            A = cell(nl+1,1);
            A{1} = Xb;
            for l = 1:nl
                Zl = A{l}*W{l}+B{l};
                if l < nl
                    A{l+1} = max(0,Zl);
                else
                    A{l+1} = Zl;
                end
            end

            % Bias correction is driven by the update count, one per
            % mini-batch, as the Adam algorithm specifies. Using the epoch
            % index instead is only correct when an epoch is one batch.
            t = t+1;
            c1 = 1-b1^t;
            c2 = 1-b2^t;

            % Backward pass with an Adam update per layer.
            dZ = (A{nl+1}-yb)/size(Xb,1);
            for l = nl:-1:1
                gW = A{l}'*dZ;
                gB = sum(dZ,1);
                % Propagate through W{l} before it is overwritten below.
                if l > 1
                    dZ = (dZ*W{l}').*(A{l}>0);
                end
                mW{l} = b1*mW{l}+(1-b1)*gW;
                vW{l} = b2*vW{l}+(1-b2)*gW.^2;
                mB{l} = b1*mB{l}+(1-b1)*gB;
                vB{l} = b2*vB{l}+(1-b2)*gB.^2;
                W{l} = W{l}-lr*(mW{l}/c1)./(sqrt(vW{l}/c2)+ea);
                B{l} = B{l}-lr*(mB{l}/c1)./(sqrt(vB{l}/c2)+ea);
            end
        end

        % Validation MSE after the epoch.
        yp = Xv;
        for l = 1:nl
            yp = yp*W{l}+B{l};
            if l < nl
                yp = max(0,yp);
            end
        end
        vm = mean((yv-yp).^2);

        if vm < bv
            bv = vm;
            wa = 0;
            Wb = W;
            Bb = B;
        else
            wa = wa+1;
            if wa >= pa
                break;
            end
        end
    end

    W = Wb;
    B = Bb;
end