% tensor-group-sym / experiments / generate_all_figures.m
% generate_all_figures.m
% Raw
%% ========================================================================
%% generate_all_figures.m
%% Publication-quality figures for the Nature paper
%%
%% FORCES WHITE BACKGROUND regardless of MATLAB theme.
%%
%% Usage: >> generate_all_figures
%% LH & Claude 2026
%% ========================================================================

function generate_all_figures(~)
    % GENERATE_ALL_FIGURES  Build and export Figures 2-5 of the paper.
    %   GENERATE_ALL_FIGURES() sets white-theme graphics defaults on groot,
    %   adds <repo>/core and <repo>/experiments to the MATLAB path, creates
    %   results/figures (relative to the current working directory) when it
    %   is missing, and then calls fig2_synthetic, fig3_qm9, fig4_product
    %   and fig5_discovery in that order, passing each the output folder and
    %   the shared colour palette. One input argument is accepted and
    %   ignored.
    %
    %   NOTE(review): the output folder is resolved against pwd, not against
    %   the repository root computed below -- confirm that is intended.

    % FORCE WHITE THEME (overrides MATLAB dark mode). These groot defaults
    % stay in effect for the rest of the MATLAB session.
    whiteDefaults = { ...
        'defaultFigureColor',   'w'; ...
        'defaultAxesColor',     'w'; ...
        'defaultAxesXColor',    'k'; ...
        'defaultAxesYColor',    'k'; ...
        'defaultTextColor',     'k'; ...
        'defaultAxesGridColor', [0.8 0.8 0.8]};
    for iDef = 1:size(whiteDefaults, 1)
        set(groot, whiteDefaults{iDef, 1}, whiteDefaults{iDef, 2});
    end

    % This file lives in <repo>/experiments; put project code on the path.
    experimentsDir = fileparts(mfilename('fullpath'));
    repoRoot       = fileparts(experimentsDir);
    addpath(fullfile(repoRoot, 'core'));
    addpath(fullfile(repoRoot, 'experiments'));

    outDir = fullfile('results', 'figures');
    if exist(outDir, 'dir') ~= 7
        mkdir(outDir);
    end

    % Shared colour palette (RGB triplets in [0,1]) handed to every figure.
    C = struct( ...
        'starG',      [0.17, 0.63, 0.37], ...
        'starG_lite', [0.40, 0.78, 0.55], ...
        'factor1',    [0.95, 0.61, 0.15], ...
        'factor2',    [0.90, 0.45, 0.10], ...
        'wrong',      [0.85, 0.33, 0.10], ...
        'base1',      [0.55, 0.55, 0.55], ...
        'base2',      [0.70, 0.70, 0.70], ...
        'base3',      [0.40, 0.40, 0.40], ...
        'accent',     [0.15, 0.35, 0.70]);

    fprintf('Generating figures in %s/\n\n', outDir);
    builders = {@fig2_synthetic, @fig3_qm9, @fig4_product, @fig5_discovery};
    for iFig = 1:numel(builders)
        builders{iFig}(outDir, C);
    end
    fprintf('\nDone. All figures in %s/\n', outDir);
end

%% Force white theme on all axes in a figure
function force_white(fig)
    % FORCE_WHITE  Repaint figure FIG with a white background for export.
    %   FORCE_WHITE(FIG) does the following:
    %     - re-applies the white groot defaults;
    %     - sets FIG's background to white with InvertHardcopy 'off', so
    %       saveas/print keep the on-screen colours;
    %     - gives every axes in FIG a white plot area, black rulers and
    %       light-grey grid lines;
    %     - forces axes titles black, and turns x/y/z labels and colorbar
    %       labels black only when they are near-white;
    %     - gives colorbars black ticks/outline, and legends a white box
    %       with black text.
    %
    %   Text created with TEXT() is deliberately left untouched. Callers
    %   that draw intentionally white annotations (e.g. value labels on dark
    %   heatmap cells) therefore keep their chosen colour.

    % FORCE WHITE THEME (overrides MATLAB dark mode)
    set(groot, 'defaultFigureColor', 'w');
    set(groot, 'defaultAxesColor', 'w');
    set(groot, 'defaultAxesXColor', 'k');
    set(groot, 'defaultAxesYColor', 'k');
    set(groot, 'defaultTextColor', 'k');
    set(groot, 'defaultAxesGridColor', [0.8 0.8 0.8]);

    % Newer releases give figures a Theme property. Pin it to light so dark
    % mode cannot restyle automatically coloured objects. isprop makes this
    % a no-op on releases without themes.
    if isprop(fig, 'Theme')
        fig.Theme = 'light';
    end
    set(fig, 'Color', 'w', 'InvertHardcopy', 'off');

    nearWhite = @(rgb) isnumeric(rgb) && all(rgb > 0.85);

    for ax = reshape(findall(fig, 'Type', 'axes'), 1, [])
        set(ax, 'Color', 'w', 'XColor', 'k', 'YColor', 'k', 'ZColor', 'k', ...
            'GridColor', [0.8, 0.8, 0.8], 'MinorGridColor', [0.9, 0.9, 0.9]);
        ax.Title.Color = [0 0 0];  % force title black
        % Only axes-owned labels are recoloured, never user TEXT() objects.
        for lbl = [ax.XLabel, ax.YLabel, ax.ZLabel]
            if nearWhite(lbl.Color)
                lbl.Color = [0 0 0];
            end
        end
    end

    % Colorbars: black ticks/tick labels/outline; fix a near-white label.
    for cb = reshape(findall(fig, 'Type', 'colorbar'), 1, [])
        cb.Color = 'k';
        if nearWhite(cb.Label.Color)
            cb.Label.Color = [0 0 0];
        end
    end

    % Legends: white box with black entry text.
    lgs = findall(fig, 'Type', 'legend');
    if ~isempty(lgs)
        set(lgs, 'Color', 'w', 'TextColor', 'k');
    end
end

%% ========================================================================
%% FIGURE 2: Synthetic Validation
%% ========================================================================
function fig2_synthetic(figDir, C)
    % FIG2_SYNTHETIC  Figure 2: synthetic validation (Z_12, 1,000 molecules).
    %   FIG2_SYNTHETIC(FIGDIR, C) draws a 2x2 panel figure from the
    %   hard-coded benchmark numbers below. It writes fig2_synthetic.pdf
    %   (vector) and fig2_synthetic.png into FIGDIR.
    %
    %   C is the colour-palette struct from generate_all_figures. Fields
    %   starG, starG_lite, base1, base2, base3 and accent are used.
    %
    %   Panels:
    %     a  Test R^2 per method (mean, with +/- R2s error bars)
    %     b  log10 of rotation variance per method
    %     c  Test R^2 against parameter count (log x-axis)
    %     d  per-method scores rescaled to [0,1]: R^2, invariance, efficiency

    % FORCE WHITE THEME (overrides MATLAB dark mode)
    set(groot, 'defaultFigureColor', 'w');
    set(groot, 'defaultAxesColor', 'w');
    set(groot, 'defaultAxesXColor', 'k');
    set(groot, 'defaultAxesYColor', 'k');
    set(groot, 'defaultTextColor', 'k');
    set(groot, 'defaultAxesGridColor', [0.8 0.8 0.8]);

    % "methodNames" rather than "methods", which would shadow MATLAB's methods().
    methodNames = {'star_G-SVD + Ridge', 'Augmented MLP', 'Neural star_G', 'Standard MLP', 'Invariant MLP'};
    R2m = [1.000, 0.998, 0.697, 0.377, 0.327];   % mean test R^2
    R2s = [0.000, 0.000, 0.063, 0.054, 0.147];   % R^2 spread (error bars)
    rv  = [5.8e-31, 3.7e-5, 1.0e-30, 1.4e-1, 1e-33];  % rotation variance
    par = [101, 5249, 8641, 5249, 14465];        % parameter counts
    col = [C.starG; C.base2; C.starG_lite; C.base1; C.base3];
    nM  = numel(R2m);

    % The 1e-33 floor keeps log10 finite for a zero variance. Panels b and d
    % both use this.
    rv_log = log10(rv + 1e-33);

    fig = figure('Position', [50,50,1400,900], 'Color', 'w');

    % Panel A: R2
    subplot(2,2,1);
    b = bar(R2m, 'FaceColor','flat','EdgeColor','none','BarWidth',0.7);
    for k = 1:nM, b.CData(k,:) = col(k,:); end
    hold on; errorbar(1:nM, R2m, R2s, 'k.','LineWidth',1.5,'CapSize',8);
    set(gca,'XTickLabel',methodNames,'FontSize',9); xtickangle(25);
    ylabel('Test R^2','FontSize',12,'FontWeight','bold');
    ylim([0,1.08]); grid on; box on;
    title('a','FontSize',14,'FontWeight','bold','Units','normalized','Position',[-0.08,1.03]);
    % Label the leading bar with its own value rather than a literal string.
    text(1, 1.04, sprintf('R^2 = %.3f', R2m(1)),'HorizontalAlignment','center','FontSize',9,...
        'FontWeight','bold','Color',C.starG);

    % Panel B: Rotation variance
    subplot(2,2,2);
    b2 = bar(rv_log,'FaceColor','flat','EdgeColor','none','BarWidth',0.7);
    for k = 1:nM, b2.CData(k,:) = col(k,:); end
    set(gca,'XTickLabel',methodNames,'FontSize',9); xtickangle(25);
    ylabel('log_{10}(Rotation Variance)','FontSize',12,'FontWeight','bold');
    title('b','FontSize',14,'FontWeight','bold','Units','normalized','Position',[-0.08,1.03]);
    grid on; box on;
    % Annotation for the ~30 orders-of-magnitude gap between methods.
    text(3, -15, {'30 orders','of magnitude'},'HorizontalAlignment','center',...
        'FontSize',10,'Color',C.accent,'FontWeight','bold');

    % Panel C: Param efficiency
    subplot(2,2,3);
    for k = 1:nM
        scatter(par(k), R2m(k), 140, col(k,:),'filled','MarkerEdgeColor','k','LineWidth',0.5);
        hold on;
        text(par(k)*1.15, R2m(k), methodNames{k},'FontSize',8,'Color',col(k,:)*0.7);
    end
    set(gca,'XScale','log','FontSize',10);
    xlabel('Parameters','FontSize',12,'FontWeight','bold');
    ylabel('Test R^2','FontSize',12,'FontWeight','bold');
    xlim([60,20000]); ylim([0,1.08]);
    title('c','FontSize',14,'FontWeight','bold','Units','normalized','Position',[-0.08,1.03]);
    % Pareto line
    plot([101,5249],[1.0,0.998],'k--','LineWidth',0.8);
    text(250, 0.96, 'Pareto frontier','FontSize',8,'FontAngle','italic','Color',[0.4,0.4,0.4]);
    grid on; box on;

    % Panel D: Summary radar-like horizontal bars
    subplot(2,2,4);
    % Three metrics normalized to [0,1]
    m_r2  = R2m;  % already 0-1
    % Invariance is min-max rescaled log variance, inverted so the smallest
    % variance scores 1.
    m_inv = 1 - (rv_log - min(rv_log)) / (max(rv_log) - min(rv_log));
    % Efficiency: 1 - log(params)/max log(params); the largest model scores 0.
    m_eff = 1 - log10(par)/max(log10(par));
    data  = [m_r2; m_inv; m_eff]';  % one row per method, one column per metric

    bh = barh(data, 'grouped', 'EdgeColor','none','BarWidth',0.85);
    cmap = [C.starG; C.accent; C.starG_lite];
    for k = 1:3, bh(k).FaceColor = cmap(k,:); bh(k).FaceAlpha = 0.8; end
    set(gca,'YTickLabel',methodNames,'FontSize',9,'YDir','reverse');
    xlabel('Normalized Score','FontSize',11);
    legend({'Prediction (R^2)','Invariance','Efficiency (1/params)'},...
        'Location','southeast','FontSize',8);
    title('d','FontSize',14,'FontWeight','bold','Units','normalized','Position',[-0.08,1.03]);
    grid on; box on;

    sgtitle('Synthetic Validation (Z_{12}, 1,000 molecules)','FontSize',14,'FontWeight','bold','Color','k');
    force_white(fig);
    exportgraphics(fig, fullfile(figDir,'fig2_synthetic.pdf'),'ContentType','vector','BackgroundColor','w');
    saveas(fig, fullfile(figDir,'fig2_synthetic.png'));
    fprintf('  fig2_synthetic saved\n');
end

%% ========================================================================
%% FIGURE 3: QM9 Real Data
%% ========================================================================
function fig3_qm9(figDir, C)
    % FIG3_QM9  Figure 3: QM9 HOMO-LUMO gap results (1,000 real molecules).
    %   FIG3_QM9(FIGDIR, C) draws three bar panels from the hard-coded
    %   results below. It writes fig3_qm9.pdf (vector) and fig3_qm9.png into
    %   FIGDIR. C is the colour-palette struct from generate_all_figures.
    %
    %   Panels:
    %     a  Test R^2 (mean, with +/- R2s error bars), zero line drawn
    %     b  RMSE (axis labelled Hartree), with +/- RMSEs error bars
    %     c  log10 of rotation variance

    % FORCE WHITE THEME (overrides MATLAB dark mode)
    set(groot, 'defaultFigureColor', 'w');
    set(groot, 'defaultAxesColor', 'w');
    set(groot, 'defaultAxesXColor', 'k');
    set(groot, 'defaultAxesYColor', 'k');
    set(groot, 'defaultTextColor', 'k');
    set(groot, 'defaultAxesGridColor', [0.8 0.8 0.8]);

    % "methodNames" rather than "methods", which would shadow MATLAB's methods().
    methodNames = {'star_G-SVD + Ridge','Augmented MLP','Standard MLP','Invariant MLP','Neural star_G'};
    R2m  = [0.556, 0.384, -10.99, -3.85, -5.15];
    R2s  = [0.047, 0.028, 5.90, 1.58, 3.10];
    RMSE = [0.035, 0.042, 0.179, 0.116, 0.127];
    RMSEs= [0.000, 0.002, 0.041, 0.024, 0.032];
    rv   = [3.2e-31, 3.9e-4, 6.7e-2, 1e-33, 2.4e-28];
    col  = [C.starG; C.base2; C.base1; C.base3; C.starG_lite];
    nM   = numel(R2m);

    fig = figure('Position', [50,50,1500,420], 'Color','w');

    % Panel A: R2
    subplot(1,3,1);
    b = bar(R2m,'FaceColor','flat','EdgeColor','none','BarWidth',0.7);
    for k = 1:nM, b.CData(k,:) = col(k,:); end
    hold on; errorbar(1:nM, R2m, R2s, 'k.','LineWidth',1.5,'CapSize',6);
    yline(0,'k-','LineWidth',1.2);
    set(gca,'XTickLabel',methodNames,'FontSize',8); xtickangle(30);
    ylabel('Test R^2','FontSize',12,'FontWeight','bold');
    % Keep every error bar fully visible. The Standard MLP whisker reaches
    % R2m - R2s = -16.89, well below min(R2m) - 2 = -12.99.
    yLo = min([R2m - 2, R2m - R2s - 0.5]);
    ylim([yLo, 1.0]);
    title('a   Test R^2','FontSize',12,'FontWeight','bold');
    text(1, 0.75, sprintf('R^2 = %.3f',R2m(1)),'HorizontalAlignment','center',...
        'FontSize',9,'FontWeight','bold','Color',C.starG);
    text(3.5, min(R2m)+0.5, 'MLPs overfit','HorizontalAlignment','center',...
        'FontSize',10,'Color',[0.7,0.15,0.15],'FontAngle','italic');
    grid on; box on;

    % Panel B: RMSE
    subplot(1,3,2);
    b2 = bar(RMSE,'FaceColor','flat','EdgeColor','none','BarWidth',0.7);
    for k = 1:nM, b2.CData(k,:) = col(k,:); end
    hold on; errorbar(1:nM, RMSE, RMSEs, 'k.','LineWidth',1.5,'CapSize',6);
    set(gca,'XTickLabel',methodNames,'FontSize',8); xtickangle(30);
    ylabel('RMSE (Hartree)','FontSize',12,'FontWeight','bold');
    title('b   RMSE (lower is better)','FontSize',12,'FontWeight','bold');
    grid on; box on;

    % Panel C: Rotation variance (1e-33 floor keeps log10 finite at zero)
    subplot(1,3,3);
    rv_log = log10(rv+1e-33);
    b3 = bar(rv_log,'FaceColor','flat','EdgeColor','none','BarWidth',0.7);
    for k = 1:nM, b3.CData(k,:) = col(k,:); end
    set(gca,'XTickLabel',methodNames,'FontSize',8); xtickangle(30);
    ylabel('log_{10}(Rotation Variance)','FontSize',12,'FontWeight','bold');
    title('c   Invariance Quality','FontSize',12,'FontWeight','bold');
    grid on; box on;

    sgtitle('QM9 HOMO-LUMO Gap (1,000 real molecules)','FontSize',14,'FontWeight','bold','Color','k');
    force_white(fig);
    exportgraphics(fig, fullfile(figDir,'fig3_qm9.pdf'),'ContentType','vector','BackgroundColor','w');
    saveas(fig, fullfile(figDir,'fig3_qm9.png'));
    fprintf('  fig3_qm9 saved\n');
end

%% ========================================================================
%% FIGURE 4: Product Group (Key Figure)
%% ========================================================================
function fig4_product(figDir, C)
    % FIG4_PRODUCT  Figure 4: product group Z_6 x Z_4 (key figure).
    %   FIG4_PRODUCT(FIGDIR, C) draws two panels and writes
    %   fig4_product.pdf (vector) and fig4_product.png into FIGDIR. C is the
    %   colour-palette struct from generate_all_figures.
    %
    %   Panels:
    %     a  Test R^2 (with +/- R2s error bars) for eight methods, with each
    %        method's parameter count printed under the x-axis
    %     b  a 6x4 map of target coefficients over angular frequency f_1
    %        (rows) and axial frequency f_2 (columns), with the coupled
    %        cells outlined in red

    % FORCE WHITE THEME (overrides MATLAB dark mode)
    set(groot, 'defaultFigureColor', 'w');
    set(groot, 'defaultAxesColor', 'w');
    set(groot, 'defaultAxesXColor', 'k');
    set(groot, 'defaultAxesYColor', 'k');
    set(groot, 'defaultTextColor', 'k');
    set(groot, 'defaultAxesGridColor', [0.8 0.8 0.8]);

    % "methodNames" rather than "methods", which would shadow MATLAB's methods().
    methodNames = {'G_1 x G_2 Ridge','Z_{24} cyclic','G_1 x G_2 MLP','Std MLP',...
                   'Inv. MLP','Aug. MLP','G_2 only (Z_4)','G_1 only (Z_6)'};
    R2m = [1.000, 0.986, 0.826, 0.488, 0.324, 0.250, 0.229, 0.155];
    R2s = [0.000, 0.002, 0.099, 0.116, 0.267, 0.239, 0.191, 0.244];
    par = [186, 157, 9473, 4481, 6785, 4481, 42, 55];
    col = [C.starG; C.wrong; C.starG_lite; C.base1; C.base3; C.base2; C.factor2; C.factor1];
    nM  = numel(R2m);

    fig = figure('Position', [50,50,1600,520], 'Color','w');

    % Panel A: Bar chart
    subplot(1,3,[1,2]);
    b = bar(R2m,'FaceColor','flat','EdgeColor','none','BarWidth',0.75);
    for k = 1:nM, b.CData(k,:) = col(k,:); end
    hold on; errorbar(1:nM, R2m, R2s, 'k.','LineWidth',1.5,'CapSize',5);
    set(gca,'XTickLabel',methodNames,'FontSize',9); xtickangle(30);
    ylabel('Test R^2','FontSize',13,'FontWeight','bold');
    ylim([0,1.08]);
    title('a   Method Comparison','FontSize',13,'FontWeight','bold');
    text(1, 1.04, sprintf('R^2 = %.3f', R2m(1)),'HorizontalAlignment','center',...
        'FontSize',10,'FontWeight','bold','Color',C.starG);

    % Bracket over the single-factor bars (7 and 8). It sits just above the
    % taller of their error-bar tops so it does not cut through the whiskers.
    sf = [7, 8];
    bracketGrey = [0.4,0.4,0.4];
    yBr = max(R2m(sf) + R2s(sf)) + 0.03;
    plot(sf, [yBr, yBr], 'Color', bracketGrey, 'LineWidth', 1.2);
    plot([sf(1), sf(1)], [yBr-0.03, yBr], 'Color', bracketGrey, 'LineWidth', 1.2);
    plot([sf(2), sf(2)], [yBr-0.03, yBr], 'Color', bracketGrey, 'LineWidth', 1.2);
    text(mean(sf), yBr+0.05, 'Single factors','HorizontalAlignment','center',...
        'FontSize',9,'FontAngle','italic','Color',C.factor1);

    % Param counts below the x-axis (y < 0 lies outside ylim; text is drawn
    % there because text objects are not clipped to the axes by default).
    for k = 1:nM
        text(k, -0.04, sprintf('%d',par(k)),'HorizontalAlignment','center',...
            'FontSize',7,'Color',[0.45,0.45,0.45]);
    end
    text(0.15, -0.04, 'Params:','FontSize',7,'Color',[0.45,0.45,0.45]);
    grid on; box on;

    % Panel B: 2D frequency heatmap
    ax2 = subplot(1,3,3);
    n1 = 6; n2 = 4;
    fmap = zeros(n1, n2);                          % rows: f_1 + 1, cols: f_2 + 1
    fmap(2,1)=1.0; fmap(1,2)=1.0;                  % single-axis
    fmap(2,2)=5.0; fmap(2,3)=3.0; fmap(3,2)=3.0;   % coupled
    coupled = [2,2; 2,3; 3,2];                     % (f1,f2) in 1-indexed

    % Custom green colormap. It has a SINGLE white entry, so only
    % exactly-zero cells are white, followed by a light-to-dark green ramp.
    % Prepending a whole block of white rows would instead paint every value
    % below half the maximum (including the 1.0 single-axis cells) white.
    nc = 64;
    g_cmap = [1, 1, 1;
              linspace(0.92,0.12,nc)', linspace(0.96,0.55,nc)', linspace(0.92,0.30,nc)'];
    imagesc(0:n2-1, 0:n1-1, fmap);
    colormap(ax2, g_cmap);
    cb = colorbar; cb.Label.String = 'Target Coefficient';
    cb.Label.FontSize = 10;
    set(gca,'YDir','normal','FontSize',10);
    xlabel('Axial frequency f_2','FontSize',11,'FontWeight','bold');
    ylabel('Angular frequency f_1','FontSize',11,'FontWeight','bold');
    title('b   2D Frequency Map','FontSize',13,'FontWeight','bold');

    % Red borders on coupled cells
    hold on;
    for i = 1:size(coupled,1)
        f1 = coupled(i,1)-1; f2 = coupled(i,2)-1;  % 0-indexed for imagesc
        rectangle('Position',[f2-0.5, f1-0.5, 1, 1],'EdgeColor','r','LineWidth',2.5);
    end

    % Value labels for every non-zero cell, read from fmap so they cannot
    % drift from the plotted data. Coupled cells get large bold white text;
    % the others get small grey text.
    [rIdx, cIdx, vals] = find(fmap);
    for i = 1:numel(vals)
        lbl = sprintf('%.1f', vals(i));
        x = cIdx(i) - 1;  y = rIdx(i) - 1;        % imagesc axes are 0-indexed
        if ismember([rIdx(i), cIdx(i)], coupled, 'rows')
            text(x, y, lbl,'HorizontalAlignment','center','FontSize',12,'FontWeight','bold','Color','w');
        else
            text(x, y, lbl,'HorizontalAlignment','center','FontSize',9,'Color',[0.3,0.3,0.3]);
        end
    end

    % Legend text (inside plot area, positioned safely)
    text(2.5, 4.5, 'Coupled (87%)','FontSize',10,'Color','r','FontWeight','bold');
    text(2.5, 3.8, 'Only G_1 x G_2','FontSize',9,'Color',[0.3,0.3,0.3],'FontAngle','italic');
    text(2.5, 3.3, 'resolves these','FontSize',9,'Color',[0.3,0.3,0.3],'FontAngle','italic');

    sgtitle('Product Group Z_6 x Z_4: Compositional Advantage',...
        'FontSize',15,'FontWeight','bold','Color','k');
    force_white(fig);
    exportgraphics(fig, fullfile(figDir,'fig4_product.pdf'),'ContentType','vector','BackgroundColor','w');
    saveas(fig, fullfile(figDir,'fig4_product.png'));
    fprintf('  fig4_product saved\n');
end

%% ========================================================================
%% FIGURE 5: Symmetry Discovery
%% ========================================================================
function fig5_discovery(figDir, C)
    % FIG5_DISCOVERY  Figure 5: data-driven symmetry discovery.
    %   FIG5_DISCOVERY(FIGDIR, C) draws two horizontal bar panels from the
    %   hard-coded scores below. It writes fig5_discovery.pdf (vector) and
    %   fig5_discovery.png into FIGDIR. C is the colour-palette struct from
    %   generate_all_figures (fields starG and accent are used).
    %
    %   Panels:
    %     a  candidate groups ranked by combined score, each annotated with
    %        its R^2; the top-ranked bar is highlighted
    %     b  candidate factorizations of n=24 by prediction R^2; the first
    %        entry is highlighted

    % FORCE WHITE THEME (overrides MATLAB dark mode)
    set(groot, 'defaultFigureColor', 'w');
    set(groot, 'defaultAxesColor', 'w');
    set(groot, 'defaultAxesXColor', 'k');
    set(groot, 'defaultAxesYColor', 'k');
    set(groot, 'defaultTextColor', 'k');
    set(groot, 'defaultAxesGridColor', [0.8 0.8 0.8]);

    fig = figure('Position',[50,50,1200,420],'Color','w');

    % Panel A: QM9 group discovery
    subplot(1,2,1);
    groups = {'Z_2','V_4','Z_4','Z_{12}','D_6','Z_6','D_3','Z_3'};
    pr2 = [0.479, 0.580, 0.590, 0.534, 0.534, 0.537, 0.537, 0.504];
    combined = [0.640, 0.620, 0.620, 0.612, 0.612, 0.603, 0.603, 0.579];

    % Rank groups best-first by combined score, keeping names and R^2 aligned.
    [cs, si] = sort(combined,'descend');
    gs = groups(si); ps = pr2(si);

    barh(cs,'FaceColor',[0.3,0.6,0.9],'EdgeColor','none','BarWidth',0.65);
    hold on;
    % Redraw the top-ranked bar in the highlight colour.
    barh(1, cs(1),'FaceColor',C.starG,'EdgeColor','none','BarWidth',0.65);
    set(gca,'YTickLabel',gs,'FontSize',10,'YDir','reverse');
    xlabel('Combined Score','FontSize',12,'FontWeight','bold');
    title('a   QM9 Group Discovery','FontSize',13,'FontWeight','bold');
    for k = 1:numel(gs)
        text(cs(k)+0.003, k, sprintf('R^2=%.3f',ps(k)),'FontSize',8,...
            'VerticalAlignment','middle');
    end
    xlim([0.55, 0.70]);
    grid on; box on;

    % Panel B: Factorization discovery (entries listed best-first)
    subplot(1,2,2);
    fn = {'Z_3 x Z_8','Z_{24} (cyclic)','Z_4 x Z_6','Z_2 x Z_{12}'};
    fr2 = [1.000, 0.985, 0.961, 0.495];

    barh(fr2,'FaceColor',[0.3,0.6,0.9],'EdgeColor','none','BarWidth',0.55);
    hold on;
    barh(1, fr2(1),'FaceColor',C.starG,'EdgeColor','none','BarWidth',0.55);
    set(gca,'YTickLabel',fn,'FontSize',10,'YDir','reverse');
    xlabel('Prediction R^2','FontSize',12,'FontWeight','bold');
    title('b   Factorization Discovery (n=24)','FontSize',13,'FontWeight','bold');
    % The right limit leaves room for the value label printed just past
    % bar 1. That label starts at fr2(1)+0.01 = 1.01, so a 1.08 limit would
    % push it over the axes edge.
    xlim([0.3, 1.15]);
    text(fr2(1)+0.01, 1, sprintf('R^2 = %.3f', fr2(1)),'FontSize',9,'FontWeight','bold',...
        'Color',C.starG,'VerticalAlignment','middle');
    text(0.55, 3.5, {'Algorithm discovers','optimal decomposition'},...
        'FontSize',10,'Color',C.accent,'FontAngle','italic','HorizontalAlignment','center');
    grid on; box on;

    sgtitle('Data-Driven Symmetry Discovery','FontSize',14,'FontWeight','bold','Color','k');
    force_white(fig);
    exportgraphics(fig, fullfile(figDir,'fig5_discovery.pdf'),'ContentType','vector','BackgroundColor','w');
    saveas(fig, fullfile(figDir,'fig5_discovery.png'));
    fprintf('  fig5_discovery saved\n');
end