%% ========================================================================
%% generate_supplementary_figures.m
%% Scatter plots, ablation cascade, paradigm comparison
%%
%% FORCES WHITE BACKGROUND.
%% Usage: >> generate_supplementary_figures
%% LH & Claude 2026
%% ========================================================================
function generate_supplementary_figures()
% GENERATE_SUPPLEMENTARY_FIGURES  Entry point: build all supplementary figures.
%   Forces a white graphics theme, puts the project's core/ and experiments/
%   folders on the path, creates results/figures (relative to the current
%   working directory) and then renders the scatter, ablation and paradigm
%   figures in that order.

% FORCE WHITE THEME (overrides MATLAB dark mode)
themeDefaults = { ...
    'defaultFigureColor',   'w'; ...
    'defaultAxesColor',     'w'; ...
    'defaultAxesXColor',    'k'; ...
    'defaultAxesYColor',    'k'; ...
    'defaultTextColor',     'k'; ...
    'defaultAxesGridColor', [0.8 0.8 0.8]};
for r = 1:size(themeDefaults, 1)
    set(groot, themeDefaults{r, :});
end

% core/ and experiments/ live next to the folder containing this file.
scriptDir   = fileparts(mfilename('fullpath'));
projectRoot = fileparts(scriptDir);
addpath(fullfile(projectRoot, 'core'));
addpath(fullfile(projectRoot, 'experiments'));

% NOTE: output folder is relative to the current working directory.
outDir = fullfile('results', 'figures');
if ~exist(outDir, 'dir')
    mkdir(outDir);
end

palette = struct( ...
    'starG',  [0.17, 0.63, 0.37], ...
    'gray',   [0.55, 0.55, 0.55], ...
    'accent', [0.15, 0.35, 0.70]);

fprintf('Generating supplementary figures...\n');
fig_scatter(outDir, palette);
fig_ablation(outDir, palette);
fig_paradigm(outDir, palette);
fprintf('Done.\n');
end
function force_white(fig)
% FORCE_WHITE  Re-theme a figure as black-on-white (overrides MATLAB dark mode).
%   Resets the root (groot) colour defaults so later graphics objects are
%   created white/black. It then recolours FIG itself and every axes found
%   inside it: white background, black X/Y axes and titles, light-gray grid.
%   InvertHardcopy is turned off so exports keep the on-screen colours.
rootDefaults = { ...
    'defaultFigureColor',   'w'; ...
    'defaultAxesColor',     'w'; ...
    'defaultAxesXColor',    'k'; ...
    'defaultAxesYColor',    'k'; ...
    'defaultTextColor',     'k'; ...
    'defaultAxesGridColor', [0.8 0.8 0.8]};
for r = 1:size(rootDefaults, 1)
    set(groot, rootDefaults{r, :});
end

set(fig, 'Color', 'w', 'InvertHardcopy', 'off');

% Iterate directly over the axes handles (row-shaped so the for-loop
% visits one handle per pass; an empty result simply skips the loop).
for ax = reshape(findall(fig, 'Type', 'axes'), 1, [])
    set(ax, 'Color', 'w', 'XColor', 'k', 'YColor', 'k', ...
        'GridColor', [0.8,0.8,0.8]);
    ax.Title.Color = [0 0 0];
end
end
%% ========================================================================
%% Predicted vs True scatter
%% ========================================================================
function fig_scatter(figDir, C)
% FIG_SCATTER  Predicted-vs-true scatter plots on a synthetic molecular task.
%   Generates n_mol random point-cloud "molecules" and a synthetic target y,
%   then fits two models on a 70/15/15 train/val/test split:
%     (a) star_G-SVD features + ridge regression (lambda chosen on val MSE)
%     (b) a plain ReLU MLP on the first slice X(:,:,1) of the feature tensor
%   Test-set predictions are plotted against truth and exported as
%   fig_scatter.pdf / fig_scatter.png in figDir.
%   figDir : output folder
%   C      : colour struct; uses C.starG and C.gray

% FORCE WHITE THEME (overrides MATLAB dark mode)
set(groot, 'defaultFigureColor', 'w');
set(groot, 'defaultAxesColor', 'w');
set(groot, 'defaultAxesXColor', 'k');
set(groot, 'defaultAxesYColor', 'k');
set(groot, 'defaultTextColor', 'k');
set(groot, 'defaultAxesGridColor', [0.8 0.8 0.8]);

rng(42);   % reproducible data, split and MLP initialisation

% ---- Synthetic dataset -------------------------------------------------
n_rot = 12; G = StarGAlgebra('cyclic', n_rot);
% Feature count passed to compute_rotated_features and extractStarGFeatures
% (NOTE(review): assumed to mean the same thing in both — confirm).
n_feat = 48;
n_mol = 500; atom_types = [1,6,7,8,9];
angles = (0:n_rot-1)*2*pi/n_rot;
e1 = [cos(angles); sin(angles); zeros(1,n_rot)];  % in-plane directions, one per rotation
y = zeros(n_mol,1); coords_all = cell(n_mol,1); Z_all = cell(n_mol,1);
for i = 1:n_mol
    na = randi([4,12]);
    pos = randn(na,3)*1.5; pos = pos - mean(pos,1);          % centred coordinates
    Z = atom_types(randi(numel(atom_types),na,1))';          % column of atomic numbers
    coords_all{i} = pos; Z_all{i} = Z;
    D = squareform(pdist(pos)); dd = D(D>0);                  % off-diagonal pair distances
    wZ = Z/(sum(Z)+1e-10);                                    % Z-weighted atom weights
    pw = abs(fft(wZ'*pos*e1)).^2;                             % power spectrum over rotations
    y(i) = mean(dd) + 0.003*pw(2) + 0.002*pw(3) + sum(Z.^2)/500;
end

% 'qm9' (not 'exp') so MATLAB's builtin exp() is not shadowed.
qm9 = QM9_experiment('dummy', n_rot);
qm9.coords = coords_all; qm9.atomic_numbers = Z_all;
qm9.n_molecules = n_mol; qm9.properties_mat = zeros(n_mol,15);
qm9.properties_mat(:,8) = y;   % target column read by the experiment (assumed — confirm)
qm9 = qm9.compute_rotated_features('n_feat', n_feat);
X = qm9.X_tensor;

% ---- Train / validation / test split -----------------------------------
idx = randperm(n_mol); ntr = round(0.7*n_mol); nva = round(0.15*n_mol);
tri = idx(1:ntr); vai = idx(ntr+1:ntr+nva); tei = idx(ntr+nva+1:end);
ytr = y(tri); yva = y(vai); yte = y(tei);

% ---- (a) star_G features + ridge ---------------------------------------
[ftr, np] = extractStarGFeatures(X(tri,:,:), G, n_feat);
fva = extractStarGFeatures(X(vai,:,:), G, n_feat, np);
fte = extractStarGFeatures(X(tei,:,:), G, n_feat, np);
p = size(ftr,2); R = eye(p); R(1,1) = 0;   % column 1 unpenalised (assumed intercept)
lams = [0.001,0.01,0.1,1,10,100]; best_err = Inf; best_lam = 0.01;
for lam = lams
    wt = (ftr'*ftr + lam*R) \ (ftr'*ytr);
    err = mean((yva - fva*wt).^2);
    if err < best_err, best_err = err; best_lam = lam; end
end
w_ridge = (ftr'*ftr + best_lam*R) \ (ftr'*ytr);
yp_starG = fte*w_ridge;

% ---- (b) Standard MLP ---------------------------------------------------
Xs = squeeze(X(:,:,1));
mu = mean(Xs(tri,:)); sig = std(Xs(tri,:)) + 1e-8;   % train-set standardisation
d_in = size(Xs,2);   % input width taken from the data, not hard-coded
[W,B] = train_simple((Xs(tri,:)-mu)./sig, ytr, (Xs(vai,:)-mu)./sig, yva, ...
    [d_in,64,32,1], 200, 0.003);
yp_mlp = (Xs(tei,:)-mu)./sig;
for l = 1:numel(W)
    yp_mlp = yp_mlp*W{l} + B{l};
    if l < numel(W), yp_mlp = max(0, yp_mlp); end   % ReLU on hidden layers only
end

% ---- Plot ---------------------------------------------------------------
r2fun = @(yt, yp) 1 - sum((yt-yp).^2)/(sum((yt-mean(yt)).^2)+1e-20);
fig = figure('Position',[100,100,900,400],'Color','w');
subplot(1,2,1);
scatter(yte, yp_starG, 25, C.starG, 'filled', 'MarkerFaceAlpha',0.5);
hold on;
% Padding by +/-5% of the extremes assumes y > 0 (true for this synthetic target).
lims = [min(yte)*0.95, max(yte)*1.05];
plot(lims, lims, 'k--', 'LineWidth',1.5);
xlabel('True y','FontSize',12,'FontWeight','bold');
ylabel('Predicted y','FontSize',12,'FontWeight','bold');
r2_sg = r2fun(yte, yp_starG);
title(sprintf('a star_G-SVD + Ridge (R^2 = %.3f)',r2_sg),'FontSize',11,'FontWeight','bold');
xlim(lims); ylim(lims); axis square; grid on; box on;
subplot(1,2,2);
scatter(yte, yp_mlp, 25, C.gray, 'filled', 'MarkerFaceAlpha',0.5);
hold on; plot(lims, lims, 'k--','LineWidth',1.5);
xlabel('True y','FontSize',12,'FontWeight','bold');
ylabel('Predicted y','FontSize',12,'FontWeight','bold');
r2_mlp = r2fun(yte, yp_mlp);
title(sprintf('b Standard MLP (R^2 = %.3f)',r2_mlp),'FontSize',11,'FontWeight','bold');
xlim(lims); ylim(lims); axis square; grid on; box on;
sgtitle('Predicted vs True','FontSize',14,'FontWeight','bold','Color','k');
force_white(fig);
exportgraphics(fig, fullfile(figDir,'fig_scatter.pdf'),'ContentType','vector','BackgroundColor','w');
saveas(fig, fullfile(figDir,'fig_scatter.png'));
fprintf(' fig_scatter saved\n');
end
%% ========================================================================
%% Ablation Cascade
%% ========================================================================
function fig_ablation(figDir, C)
% FIG_ABLATION  Horizontal bar chart of test R^2 as symmetry components are removed.
%   The R^2 values are hard-coded (pre-computed results, not recomputed here).
%   Bars are drawn so the first entry of LABELS/R2 appears at the top.
%   Output: fig_ablation.pdf / fig_ablation.png in figDir.
%   figDir : output folder
%   C      : colour struct; uses C.starG (full model) and C.gray (no symmetry)

% FORCE WHITE THEME (overrides MATLAB dark mode)
set(groot, 'defaultFigureColor', 'w');
set(groot, 'defaultAxesColor', 'w');
set(groot, 'defaultAxesXColor', 'k');
set(groot, 'defaultAxesYColor', 'k');
set(groot, 'defaultTextColor', 'k');
set(groot, 'defaultAxesGridColor', [0.8 0.8 0.8]);
fig = figure('Position',[100,100,650,380],'Color','w');
labels = {'G_1 x G_2 (product)', 'Z_{24} (wrong cyclic)', ...
    'G_2 only (Z_4)', 'G_1 only (Z_6)', 'No symmetry (MLP)'};
r2 = [1.000, 0.986, 0.229, 0.155, 0.114];
colors = [C.starG; [0.85,0.33,0.1]; [0.9,0.45,0.1]; [0.95,0.61,0.15]; C.gray];
n = numel(r2);
assert(numel(labels) == n && size(colors,1) == n, ...
    'fig_ablation: labels, r2 and colors must have the same length');
% barh stacks bars bottom-up, so flip values/colours/labels to put entry 1 on top.
r2_plot = flip(r2);
b = barh(1:n, r2_plot, 'FaceColor','flat','EdgeColor','none','BarWidth',0.6);
b.CData = flip(colors, 1);
% Pin ticks to the bar positions so each label is guaranteed to sit on its bar.
set(gca,'YTick',1:n,'YTickLabel',flip(labels),'FontSize',10);
xlabel('Test R^2','FontSize',12,'FontWeight','bold');
title('Ablation: Removing Symmetry Components','FontSize',13,'FontWeight','bold');
xlim([0,1.1]);
hold on;
for k = 1:n
    v = r2_plot(k);
    text(v+0.015, k, sprintf('%.3f',v),'FontSize',10,'VerticalAlignment','middle','FontWeight','bold');
end
grid on; box on;
force_white(fig);
exportgraphics(fig, fullfile(figDir,'fig_ablation.pdf'),'ContentType','vector','BackgroundColor','w');
saveas(fig, fullfile(figDir,'fig_ablation.png'));
fprintf(' fig_ablation saved\n');
end
%% ========================================================================
%% Paradigm Comparison (ENN vs star_G)
%% ========================================================================
function fig_paradigm(figDir, C)
% FIG_PARADIGM  Schematic contrasting the ENN paradigm with the Star-G paradigm.
%   The diagram is drawn with text/rectangle/plot inside one full-figure axes.
%   The connecting arrows use annotation(), which works in normalized figure
%   units. Output: fig_paradigm.pdf / fig_paradigm.png in figDir.
%   figDir : output folder
%   C      : colour struct; uses C.starG for the right-hand panel

% FORCE WHITE THEME (overrides MATLAB dark mode)
set(groot, 'defaultFigureColor', 'w');
set(groot, 'defaultAxesColor', 'w');
set(groot, 'defaultAxesXColor', 'k');
set(groot, 'defaultAxesYColor', 'k');
set(groot, 'defaultTextColor', 'k');
set(groot, 'defaultAxesGridColor', [0.8 0.8 0.8]);
fig = figure('Position',[100,100,1000,480],'Color','w');
% Full-figure axes with limits LOCKED to [0,1]. Data coordinates then coincide
% with the normalized figure coordinates used by annotation() below. With
% auto limits, later plot/rectangle calls (under hold on) could rescale the
% axes and shift the boxes relative to the arrows.
ax = axes('Position',[0,0,1,1],'XLim',[0,1],'YLim',[0,1]);
axis(ax,'off'); hold(ax,'on');
% Title
text(0.50, 0.96, 'Paradigm Comparison: Architecture vs Algebra',...
    'FontSize',16,'FontWeight','bold','HorizontalAlignment','center');
% Divider line
plot([0.50, 0.50],[0.05, 0.90],'Color',[0.8,0.8,0.8],'LineWidth',1.5);
% === LEFT: ENN ===
enn_red = [0.80, 0.25, 0.20];
enn_bg = [1.0, 0.92, 0.90];
enn_bg2 = [1.0, 0.88, 0.82];
text(0.25, 0.90, 'ENN Paradigm','FontSize',14,'FontWeight','bold',...
    'HorizontalAlignment','center','Color',enn_red);
% Box 1: Rotation
rectangle('Position',[0.04, 0.68, 0.42, 0.14],'Curvature',0.15,...
    'FaceColor',enn_bg,'EdgeColor',enn_red,'LineWidth',1.5);
text(0.25, 0.75, {'Rotation symmetry:','Design SE(3)-equivariant layers'},...
    'FontSize',10,'HorizontalAlignment','center','Color',[0.2,0.2,0.2]);
% Arrow (normalized figure units == data units, see axes setup above)
annotation('arrow',[0.25,0.25],[0.68,0.64],'Color',[0.5,0.5,0.5],'LineWidth',1);
% Box 2: Permutation
rectangle('Position',[0.04, 0.46, 0.42, 0.14],'Curvature',0.15,...
    'FaceColor',enn_bg,'EdgeColor',enn_red,'LineWidth',1.5);
text(0.25, 0.53, {'Permutation symmetry:','Design set-equivariant layers'},...
    'FontSize',10,'HorizontalAlignment','center','Color',[0.2,0.2,0.2]);
% Arrow
annotation('arrow',[0.25,0.25],[0.46,0.42],'Color',[0.5,0.5,0.5],'LineWidth',1);
% Box 3: Both (emphasized)
rectangle('Position',[0.04, 0.22, 0.42, 0.16],'Curvature',0.15,...
    'FaceColor',enn_bg2,'EdgeColor',[0.85,0.30,0.10],'LineWidth',2.5);
% NOTE(review): '???' may be a mis-encoded glyph from an earlier edit — confirm intended text.
text(0.25, 0.30, {'Both symmetries:','Redesign architecture from scratch ???'},...
    'FontSize',10,'HorizontalAlignment','center','FontWeight','bold','Color',[0.3,0.1,0.05]);
text(0.25, 0.10, 'Each symmetry = new architecture',...
    'FontSize',11,'HorizontalAlignment','center','FontAngle','italic','Color',[0.5,0.2,0.2]);
% === RIGHT: star_G ===
sg = C.starG;
sg_bg = [0.88, 0.96, 0.90];
text(0.75, 0.90, 'Star-G Paradigm',...
    'FontSize',14,'FontWeight','bold','HorizontalAlignment','center','Color',sg);
% Main algebra box
rectangle('Position',[0.56, 0.62, 0.38, 0.18],'Curvature',0.15,...
    'FaceColor',sg_bg,'EdgeColor',sg,'LineWidth',2.5);
text(0.75, 0.71, {'Star-G Algebra','(SVD, Ridge, features)'},...
    'FontSize',12,'HorizontalAlignment','center','FontWeight','bold','Color',sg*0.7);
% Arrow down
annotation('arrow',[0.75,0.75],[0.62,0.57],'Color',sg,'LineWidth',1.5);
% Three input boxes
inputs = {'Rotation: G = Z_{12}', ...
    'Permutation: G = S_n', ...
    'Both: G = Z_{12} x S_n'};
ypos = [0.48, 0.36, 0.22];
for k = 1:numel(inputs)
    rectangle('Position',[0.56, ypos(k)-0.01, 0.38, 0.10],'Curvature',0.15,...
        'FaceColor',[0.92,0.97,0.93],'EdgeColor',sg*0.8,'LineWidth',1.2);
    text(0.75, ypos(k)+0.04, inputs{k},...
        'FontSize',10,'HorizontalAlignment','center','Color',[0.15,0.15,0.15]);
end
text(0.75, 0.10, 'Same algebra, just specify the group',...
    'FontSize',11,'HorizontalAlignment','center','FontAngle','italic','Color',sg*0.7);
force_white(fig);
exportgraphics(fig, fullfile(figDir,'fig_paradigm.pdf'),'ContentType','vector','BackgroundColor','w');
saveas(fig, fullfile(figDir,'fig_paradigm.png'));
fprintf(' fig_paradigm saved\n');
end
%% Standalone MLP
function [W,B] = train_simple(X,y,Xv,yv,layers,maxep,lr)
% TRAIN_SIMPLE  Fit a fully connected ReLU MLP for regression (Adam + early stopping).
%   X, y    : training inputs (one row per sample) and column-vector targets
%   Xv, yv  : validation inputs/targets used for early stopping
%   layers  : layer widths [d_in, h1, ..., d_out]; layers(1) must equal size(X,2)
%   maxep   : maximum number of epochs
%   lr      : Adam learning rate
%   Returns the weights W{l} (layers(l) x layers(l+1)) and biases B{l}
%   (1 x layers(l+1)) from the epoch with the lowest validation MSE.
%   Hidden layers use ReLU; the output layer is linear. Training stops once
%   validation MSE has not improved for 20 consecutive epochs.
%   (No graphics defaults are touched here; theming is the callers' job.)
nl = numel(layers) - 1;
W = cell(nl,1); B = cell(nl,1);
for l = 1:nl
    fanIn = layers(l);
    W{l} = randn(fanIn, layers(l+1))*sqrt(2/fanIn);   % He initialisation for ReLU
    B{l} = zeros(1, layers(l+1));
end
% Adam moment estimates
mW = cellfun(@(w) zeros(size(w)), W, 'Uni', 0); vW = mW;
mB = cellfun(@(b) zeros(size(b)), B, 'Uni', 0); vB = mB;
b1 = .9; b2 = .999; ea = 1e-8;
% Early-stopping state
bv = Inf; pa = 20; wa = 0; Wb = W; Bb = B;
nn = size(X,1); bs = min(256, nn);
t = 0;   % global Adam step count (bias correction must use updates, not epochs)
for ep = 1:maxep
    pm = randperm(nn);
    for s = 1:bs:nn
        t = t + 1;
        bi = pm(s:min(s+bs-1, nn)); Xb = X(bi,:); yb = y(bi);
        % Forward pass, keeping activations for backprop
        A = cell(nl+1,1); A{1} = Xb;
        for l = 1:nl
            Zl = A{l}*W{l} + B{l};
            if l < nl, A{l+1} = max(0, Zl); else, A{l+1} = Zl; end
        end
        % Gradient of 0.5*MSE with respect to the output
        dZ = (A{nl+1} - yb)/size(Xb,1);
        bc1 = 1 - b1^t; bc2 = 1 - b2^t;
        for l = nl:-1:1
            gW = A{l}'*dZ; gB = sum(dZ,1);
            % Propagate through W{l} BEFORE it is updated below
            if l > 1, dZ = (dZ*W{l}').*(A{l} > 0); end
            mW{l} = b1*mW{l} + (1-b1)*gW; vW{l} = b2*vW{l} + (1-b2)*gW.^2;
            mB{l} = b1*mB{l} + (1-b1)*gB; vB{l} = b2*vB{l} + (1-b2)*gB.^2;
            W{l} = W{l} - lr*(mW{l}/bc1)./(sqrt(vW{l}/bc2) + ea);
            B{l} = B{l} - lr*(mB{l}/bc1)./(sqrt(vB{l}/bc2) + ea);
        end
    end
    % Validation MSE for early stopping
    yp = Xv;
    for l = 1:nl
        yp = yp*W{l} + B{l};
        if l < nl, yp = max(0, yp); end
    end
    vm = mean((yv - yp).^2);
    if vm < bv
        bv = vm; wa = 0; Wb = W; Bb = B;
    else
        wa = wa + 1;
        if wa >= pa, break; end
    end
end
W = Wb; B = Bb;
end