%% ========================================================================
%% QM9_experiment.m
%% Real-world validation on QM9 molecular property prediction
%%
%% LAPTOP-FRIENDLY: default 1000 molecules, timing printed per step.
%% LH & Claude 2026
%% ========================================================================
classdef QM9_experiment < handle
properties
% data_dir       : directory searched for '*.xyz' files by load_data
% n_molecules    : number of molecules currently held (0 until data is
%                  loaded or generated)
% n_rotations    : order of the cyclic group; also the number of equally
%                  spaced rotation angles used for the features
% max_atoms      : stored by the constructor (default 29); not read by any
%                  method of this class
data_dir; n_molecules; n_rotations; max_atoms
% coords         : cell array, one (n_atoms x 3) position matrix per molecule
% atomic_numbers : cell array, one (n_atoms x 1) vector of atomic numbers
%                  per molecule
% properties_mat : (n_molecules x 15) numeric matrix of per-molecule targets
coords; atomic_numbers; properties_mat
% G              : StarGAlgebra('cyclic', n_rotations), built in the constructor
% X_tensor       : (n_molecules x n_feat_per_rot x n_rotations) feature
%                  tensor filled by compute_rotated_features
% n_feat_per_rot : number of feature rows per rotation
% results        : struct stored by the last run_comparison call
G; X_tensor; n_feat_per_rot; results
is_real_data % true when the data came from .xyz files, false when synthetic
end
methods
function obj = QM9_experiment(data_dir, n_rotations, varargin)
% QM9_EXPERIMENT  Build an experiment over the cyclic group Z_n.
%
%   data_dir    : directory that load_data will search for .xyz files
%   n_rotations : group order / number of rotation angles
%   Name-value options:
%     'max_atoms'   (default 29)  stored on the object
%     'n_molecules' (default Inf) accepted by the parser but not stored;
%                   the molecule count always starts at 0
    parser = inputParser;
    parser.addRequired('data_dir');
    parser.addRequired('n_rotations');
    parser.addParameter('max_atoms', 29);
    parser.addParameter('n_molecules', Inf);
    parser.parse(data_dir, n_rotations, varargin{:});
    opts = parser.Results;

    obj.data_dir     = opts.data_dir;
    obj.n_rotations  = opts.n_rotations;
    obj.max_atoms    = opts.max_atoms;
    obj.n_molecules  = 0;
    obj.is_real_data = false;

    group_order = obj.n_rotations;
    obj.G = StarGAlgebra('cyclic', group_order);
    fprintf('QM9 Experiment: |G| = %d (Z_%d)\n', group_order, group_order);
end
%% DATA LOADING ===================================================
function obj = load_data(obj, n_max)
% LOAD_DATA  Load molecules from obj.data_dir, or synthesize them.
%
%   n_max : optional cap on the number of molecules (default Inf).
%
%   If data_dir is a folder containing at least one '*.xyz' file, those
%   files are parsed and is_real_data is set to true. Otherwise a synthetic
%   dataset of min(n_max, 1000) molecules is generated.
    if nargin < 2
        n_max = Inf;
    end

    found = [];
    if isfolder(obj.data_dir)
        found = dir(fullfile(obj.data_dir, '*.xyz'));
    end

    if isempty(found)
        fprintf('No .xyz files in "%s". Generating synthetic molecules.\n', obj.data_dir);
        obj = obj.generate_synthetic(min(n_max, 1000));
        return;
    end

    obj = obj.load_xyz_files(found, n_max);
    obj.is_real_data = true;
end
function obj = load_xyz_files(obj, xyz_files, n_max)
% LOAD_XYZ_FILES  Parse up to n_max QM9-style .xyz files into the object.
%
%   xyz_files : struct array as returned by dir() (.folder and .name are used)
%   n_max     : maximum number of files to read (may be Inf)
%
%   Expected layout per file:
%     line 1  : atom count
%     line 2  : property line; the first two whitespace-separated tokens are
%               skipped and up to 15 of the following numeric tokens are kept
%     line 3+ : one line per atom: element symbol, x, y, z (extra columns
%               are ignored)
%   The Mathematica-style exponent marker '*^' is rewritten to 'e' before
%   numeric conversion. Unknown element symbols fall back to Z = 6.
%
%   A file is skipped when it cannot be opened or parsed, when its atom
%   count is invalid, when any coordinate is NaN, or when it contains fewer
%   atom lines than the declared atom count. Skipped files are counted and
%   reported instead of being dropped silently.
    n_files = min(length(xyz_files), n_max);
    fprintf('Loading %d QM9 .xyz files...\n', n_files);
    t0 = tic;
    obj.coords = cell(n_files,1);
    obj.atomic_numbers = cell(n_files,1);
    obj.properties_mat = zeros(n_files, 15);
    elem_map = containers.Map({'H','C','N','O','F','S','Cl','Br','I','P','Si','B'}, ...
        {1,6,7,8,9,16,17,35,53,15,14,5});
    valid = 0;
    n_skipped = 0;
    for f = 1:n_files
        fpath = fullfile(xyz_files(f).folder, xyz_files(f).name);
        fid = -1;
        ok = false;
        try
            fid = fopen(fpath, 'r');
            n_atoms = str2double(fgetl(fid));
            if ~isnan(n_atoms) && n_atoms >= 1
                prop_line = strrep(fgetl(fid), '*^', 'e');
                tokens = strsplit(strtrim(prop_line));
                props = zeros(1, min(15, length(tokens)-2));
                for t = 1:length(props)
                    val = str2double(tokens{t+2});
                    if ~isnan(val), props(t) = val; end
                end
                Z = zeros(n_atoms,1);
                pos = zeros(n_atoms,3);
                n_read = 0;   % atom lines parsed successfully so far
                for a = 1:n_atoms
                    txt = fgetl(fid);
                    if ~ischar(txt), break; end          % hit end of file early
                    % QM9 uses tabs and may have extra columns
                    txt = strrep(txt, char(9), ' ');
                    parts = strsplit(strtrim(txt));
                    if length(parts) < 4, break; end     % malformed atom line
                    elem = parts{1};
                    if isKey(elem_map, elem)
                        Z(a) = elem_map(elem);
                    else
                        Z(a) = 6;
                    end
                    pos(a,:) = [str2double(strrep(parts{2},'*^','e')), ...
                                str2double(strrep(parts{3},'*^','e')), ...
                                str2double(strrep(parts{4},'*^','e'))];
                    n_read = a;
                end
                % Rows left at their zero initial value would otherwise pass
                % the NaN check, so require every declared atom to be read.
                ok = (n_read == n_atoms) && ~any(isnan(pos(:)));
            end
        catch
            ok = false;
        end
        % Single close point: runs for success, skip and error paths alike.
        if fid > 0, fclose(fid); end
        if ok
            valid = valid + 1;
            obj.coords{valid} = pos;
            obj.atomic_numbers{valid} = Z;
            obj.properties_mat(valid, 1:length(props)) = props;
        else
            n_skipped = n_skipped + 1;
        end
        if mod(f,2000)==0, fprintf(' %d / %d (%.1fs)\n', f, n_files, toc(t0)); end
    end
    obj.coords = obj.coords(1:valid);
    obj.atomic_numbers = obj.atomic_numbers(1:valid);
    obj.properties_mat = obj.properties_mat(1:valid,:);
    obj.n_molecules = valid;
    fprintf('Loaded %d molecules in %.1fs.\n', valid, toc(t0));
    if n_skipped > 0
        fprintf('Skipped %d unreadable, malformed or truncated file(s).\n', n_skipped);
    end
end
function obj = generate_synthetic(obj, n_mol)
% GENERATE_SYNTHETIC  Fill the object with n_mol random molecules.
%
%   Each molecule has 4..12 atoms drawn from {H, C, N, O, F} with centred
%   Gaussian positions (std 1.5). Only column 8 of properties_mat is
%   populated; it is
%       mean pairwise distance
%       + 0.003*P(2) + 0.002*P(3)
%       + sum(Z.^2)/500
%   where P is the power spectrum (|fft|^2) over the rotation angles of the
%   Z-weighted projection of the positions onto the in-plane direction.
%
%   The global random stream is reset with rng(42,'twister'), so output is
%   reproducible and the caller's RNG state is overwritten.
%
%   With fewer than 3 rotations the spectrum has no second/third bin; the
%   missing bins are treated as zero instead of raising an index error.
    fprintf('Generating %d synthetic molecules...\n', n_mol);
    rng(42,'twister');
    atom_types = [1,6,7,8,9];
    obj.coords = cell(n_mol,1);
    obj.atomic_numbers = cell(n_mol,1);
    obj.properties_mat = zeros(n_mol, 15);
    n_rot = obj.n_rotations;
    angles = (0:n_rot-1)*2*pi/n_rot;
    e1 = [cos(angles); sin(angles); zeros(1,n_rot)];
    for i = 1:n_mol
        % Draw order (atom count, positions, species) fixes the random stream.
        na = randi([4,12]);
        pos = randn(na,3)*1.5;
        pos = pos - mean(pos,1);
        Z = reshape(atom_types(randi(5,na,1)), [], 1);
        obj.coords{i} = pos;
        obj.atomic_numbers{i} = Z;
        D = squareform(pdist(pos));
        dd = D(D>0);
        w = Z/(sum(Z)+1e-10);
        pw = abs(fft(w'*pos*e1)).^2;
        pw(end+1:3) = 0;   % no-op when n_rot >= 3; zero-fills missing bins otherwise
        obj.properties_mat(i,8) = mean(dd) + 0.003*pw(2) + 0.002*pw(3) + sum(Z.^2)/500;
    end
    obj.n_molecules = n_mol;
    fprintf('Target = d_mean + angular_freq_power + Z_term\n');
end
%% ANGULAR PROJECTION FEATURES ====================================
function obj = compute_rotated_features(obj, varargin)
% COMPUTE_ROTATED_FEATURES  Build the molecule x feature x rotation tensor.
%
%   Name-value options:
%     'n_feat' (default 48)  number of feature rows per rotation
%
%   Fills obj.X_tensor (n_molecules x n_feat x n_rotations) and
%   obj.n_feat_per_rot, prints progress every 500 molecules, then runs
%   verify_cyclic_structure. Raises an error when no molecules are loaded.
    p = inputParser; addParameter(p,'n_feat',48); parse(p,varargin{:});
    n_feat_target = p.Results.n_feat;
    n_mol = obj.n_molecules; n_rot = obj.n_rotations;
    if n_mol < 1
        error('QM9_experiment:noData', ...
            'No molecules available; call load_data before compute_rotated_features.');
    end
    angles = (0:n_rot-1)*2*pi/n_rot;
    % Probe one molecule to learn the feature-row count actually produced.
    feat0 = obj.angular_features(obj.coords{1}, obj.atomic_numbers{1}, angles, n_feat_target);
    n_feat = size(feat0,1); obj.n_feat_per_rot = n_feat;
    X = zeros(n_mol, n_feat, n_rot);
    t0 = tic;
    fprintf('Computing features: %d molecules x %d rotations...\n', n_mol, n_rot);
    for mol = 1:n_mol
        pos = obj.coords{mol} - mean(obj.coords{mol},1);
        X(mol,:,:) = obj.angular_features(pos, obj.atomic_numbers{mol}, angles, n_feat_target);
        if mod(mol,500)==0, fprintf(' %d/%d (%.1fs)\n', mol, n_mol, toc(t0)); end
    end
    obj.X_tensor = X;
    % size(X) drops the trailing dimension when n_rot == 1, which would shift
    % the fprintf arguments; query each dimension explicitly instead.
    fprintf('Feature tensor: %d x %d x %d (%.1fs)\n', size(X,1), size(X,2), size(X,3), toc(t0));
    obj.verify_cyclic_structure();
end
function F = angular_features(~, coords, Z, angles, n_feat_target)
% ANGULAR_FEATURES  Per-rotation feature matrix for one molecule.
%
%   coords        : (n_atoms x 3) positions (centred internally)
%   Z             : atomic numbers, one per atom (row or column)
%   angles        : rotation angles about the z axis, in radians
%   n_feat_target : number of rows to return (zero-padded or truncated)
%
%   Returns F (n_feat_target x numel(angles)). Column g holds the features
%   of the molecule projected on the in-plane frame at angles(g).
%
%   Row layout before padding/truncation (48 rows in total):
%      1-19  weighted moments of the projections (orders 1-3)
%     20-31  projections (p1, p2, pz) of the 4 atoms with the largest Z
%     32-37  Z-weighted pair differences between the heaviest atom and the
%            next 3 atoms
%     38-41  mean/std/min/max pairwise distance
%     42-45  4 largest distances from the centroid
%     46-48  sum(Z.^2)/100, mean(Z), atom count
%
%   Molecules with fewer than 4 atoms get zero rows for the missing atoms
%   and pairs, so a given row index means the same feature for every
%   molecule regardless of its size.
    n_rot = length(angles); n_atoms = size(coords,1);
    Zn = Z(:); w = Zn/(sum(Zn)+1e-10);
    coords_c = coords - mean(coords, 1);
    zero_row = zeros(1,n_rot);
    feat_list = {};

    % In-plane frame at every angle, and the projections onto it.
    e1 = [cos(angles); sin(angles); zeros(1,n_rot)];
    e2 = [-sin(angles); cos(angles); zeros(1,n_rot)];
    p1 = coords_c*e1; p2 = coords_c*e2;
    pz = coords_c*repmat([0;0;1],1,n_rot);

    % Equivariant moments
    feat_list{end+1}=w'*p1; feat_list{end+1}=w'*p2; feat_list{end+1}=w'*pz;
    feat_list{end+1}=w'*(p1.^2); feat_list{end+1}=w'*(p2.^2); feat_list{end+1}=w'*(pz.^2);
    feat_list{end+1}=w'*(p1.*p2); feat_list{end+1}=w'*(p1.*pz); feat_list{end+1}=w'*(p2.*pz);
    feat_list{end+1}=Zn'*p1; feat_list{end+1}=Zn'*p2;
    feat_list{end+1}=Zn'*(p1.^2); feat_list{end+1}=Zn'*(p2.^2);
    feat_list{end+1}=Zn'*(p1.*p2); feat_list{end+1}=Zn'*(p1.*pz);
    feat_list{end+1}=w'*(p1.^3); feat_list{end+1}=w'*(p2.^3);
    feat_list{end+1}=Zn'*(p1.^3); feat_list{end+1}=Zn'*(p2.^3);

    % Atoms ordered by decreasing atomic number.
    [~,si] = sort(Zn,'descend');

    % Projections of the 4 heaviest atoms; absent atoms contribute zeros.
    for k = 1:4
        if k <= n_atoms
            idx = si(k);
            feat_list{end+1}=p1(idx,:); feat_list{end+1}=p2(idx,:); feat_list{end+1}=pz(idx,:);
        else
            feat_list(end+1:end+3) = {zero_row};
        end
    end

    % Pair terms between the heaviest atom and the next 3; absent pairs
    % contribute zeros.
    for pp = 1:3
        if pp+1 <= n_atoms
            ii = si(1); jj = si(pp+1);
            d_ij = norm(coords_c(ii,:)-coords_c(jj,:))+1e-8;
            feat_list{end+1}=Zn(ii)*Zn(jj)/d_ij*(p1(ii,:)-p1(jj,:));
            feat_list{end+1}=Zn(ii)*Zn(jj)/d_ij*(p2(ii,:)-p2(jj,:));
        else
            feat_list(end+1:end+2) = {zero_row};
        end
    end

    % Invariant features
    if n_atoms>=2
        D = pdist(coords_c);
        for v=[mean(D),std(D),min(D),max(D)], feat_list{end+1}=repmat(v,1,n_rot); end
    else
        for k=1:4, feat_list{end+1}=zero_row; end
    end
    r = sqrt(sum(coords_c.^2,2)); rs = sort(r,'descend');
    for k=1:min(4,n_atoms), feat_list{end+1}=repmat(rs(k),1,n_rot); end
    for k=n_atoms+1:4, feat_list{end+1}=zero_row; end
    feat_list{end+1}=repmat(sum(Zn.^2)/100,1,n_rot);
    feat_list{end+1}=repmat(mean(Zn),1,n_rot);
    feat_list{end+1}=repmat(n_atoms,1,n_rot);

    F = cell2mat(feat_list');
    nr = size(F,1);
    if nr<n_feat_target
        F = [F; zeros(n_feat_target-nr,n_rot)];
    elseif nr>n_feat_target
        F = F(1:n_feat_target,:);
    end
end
function verify_cyclic_structure(obj)
% VERIFY_CYCLIC_STRUCTURE  Sanity-check rotation equivariance of the features.
%
%   Takes one molecule (the 5th, or the last one if fewer are loaded),
%   rotates it about z by one group step (2*pi/n_rotations) and checks that
%   its feature matrix equals the unrotated one circularly shifted by one
%   column. Prints the relative error followed by PASS, or by FAIL and the
%   tolerance when the error is not below 1e-10. Does nothing when no
%   molecules are loaded.
    if obj.n_molecules<1, return; end
    tol = 1e-10;
    mi = min(5,obj.n_molecules);
    pos = obj.coords{mi}-mean(obj.coords{mi},1);
    Z = obj.atomic_numbers{mi};
    nr = obj.n_rotations;
    ang = (0:nr-1)*2*pi/nr;
    F0 = obj.angular_features(pos,Z,ang,obj.n_feat_per_rot);
    th = 2*pi/nr;
    R1 = [cos(th),-sin(th),0; sin(th),cos(th),0; 0,0,1];
    Fr = obj.angular_features((R1*pos')',Z,ang,obj.n_feat_per_rot);
    err = norm(Fr(:)-reshape(circshift(F0,1,2),[],1))/(norm(F0(:))+1e-20);
    fprintf('Cyclic check: %.2e', err);
    if err < tol
        fprintf(' PASS\n\n');
    else
        % State the outcome explicitly rather than repeating the error value.
        fprintf(' FAIL (tolerance %.0e)\n\n', tol);
    end
end
%% COMPARISON =====================================================
function results = run_comparison(obj, target_col, varargin)
% RUN_COMPARISON  Benchmark five regressors on one property column.
%
%   target_col : column of obj.properties_mat used as the regression target
%   Name-value options:
%     'n_seeds' (default 3)               number of random splits
%     'split'   (default [0.7,0.15,0.15]) train/validation fractions; only
%               the first two entries are used, the test set is whatever
%               remains after rounding
%
%   Reads obj.X_tensor, so compute_rotated_features must have been run.
%   For every seed the molecules are shuffled, split, and five models are
%   fitted and scored on the test split (R2, RMSE, and a rotation-variance
%   measure). A summary table (mean +/- std over seeds) is printed; the raw
%   per-seed numbers are returned and also stored in obj.results.
%
%   The global RNG is reseeded once per seed and every train_mlp call then
%   draws from that same stream, so the order of the five methods below
%   affects the numerical results.
p=inputParser; addParameter(p,'n_seeds',3); addParameter(p,'split',[0.7,0.15,0.15]);
parse(p, varargin{:});
n_seeds=p.Results.n_seeds; sp=p.Results.split;
y=obj.properties_mat(:,target_col); n=obj.n_molecules;
mn={'starG_SVD_Ridge','Neural_starG','Standard_MLP','Invariant_MLP','Augmented_MLP'};
nm=length(mn);
% Per-method, per-seed score tables; params holds one parameter count per method.
R2_all=zeros(nm,n_seeds); RMSE_all=zeros(nm,n_seeds);
rot_var=zeros(nm,n_seeds); params=zeros(nm,1);
fprintf('\n============================================================\n');
fprintf(' Comparison: col=%d, %d mol, %d seeds\n', target_col, n, n_seeds);
fprintf('============================================================\n');
for seed=1:n_seeds
fprintf('\n, Seed %d/%d , \n', seed, n_seeds);
% Reseed, then shuffle molecule indices and cut them into train/val/test.
rng(seed*111,'twister'); idx=randperm(n);
ntr=round(sp(1)*n); nva=round(sp(2)*n);
tri=idx(1:ntr); vai=idx(ntr+1:ntr+nva); tei=idx(ntr+nva+1:end);
Xtr=obj.X_tensor(tri,:,:); Xva=obj.X_tensor(vai,:,:); Xte=obj.X_tensor(tei,:,:);
ytr=y(tri); yva=y(vai); yte=y(tei);
% 1. star_G-SVD features + ridge regression.
%    extractStarGFeatures is defined outside this file; its second output
%    (np) from the training call is passed back in for val/test --
%    presumably fitted normalisation/projection parameters, TODO confirm.
%    NOTE(review): ridge_cv leaves the first column unpenalised, which
%    assumes column 1 of these features is an intercept -- confirm.
t0=tic; fprintf(' [1] star_G-SVD + Ridge...');
[ftr,np]=extractStarGFeatures(Xtr,obj.G,obj.n_feat_per_rot);
fva=extractStarGFeatures(Xva,obj.G,obj.n_feat_per_rot,np);
fte=extractStarGFeatures(Xte,obj.G,obj.n_feat_per_rot,np);
[w1,~]=obj.ridge_cv(ftr,ytr,fva,yva); yp1=fte*w1;
R2_all(1,seed)=obj.r2(yte,yp1); RMSE_all(1,seed)=sqrt(mean((yte-yp1).^2));
rot_var(1,seed)=obj.measure_rotation_variance(Xte,np,w1);
params(1)=numel(w1);
fprintf(' R2=%.4f rv=%.1e (%.1fs)\n', R2_all(1,seed), rot_var(1,seed), toc(t0));
% 2. Same star_G features as method 1, fed to an MLP (64-32-1).
t0=tic; fprintf(' [2] Neural star_G...');
nin=size(ftr,2); [wm,bm]=obj.train_mlp(ftr,ytr,fva,yva,[nin,64,32,1],300,0.003);
yp2=obj.predict_mlp(fte,wm,bm);
R2_all(2,seed)=obj.r2(yte,yp2); RMSE_all(2,seed)=sqrt(mean((yte-yp2).^2));
rot_var(2,seed)=obj.measure_rotation_variance(Xte,np,[],wm,bm);
params(2)=sum(cellfun(@numel,wm))+sum(cellfun(@numel,bm));
fprintf(' R2=%.4f rv=%.1e (%.1fs)\n', R2_all(2,seed), rot_var(2,seed), toc(t0));
% 3. Standard MLP on the raw features of rotation slice 1 only,
%    z-scored with training-set mean/std.
t0=tic; fprintf(' [3] Standard MLP...');
Xs1=squeeze(Xtr(:,:,1)); Xsv=squeeze(Xva(:,:,1)); Xst=squeeze(Xte(:,:,1));
[mu3,s3]=deal(mean(Xs1),std(Xs1)+1e-8);
[w3,b3]=obj.train_mlp((Xs1-mu3)./s3,ytr,(Xsv-mu3)./s3,yva,[size(Xs1,2),64,32,1],300,0.003);
yp3=obj.predict_mlp((Xst-mu3)./s3,w3,b3);
R2_all(3,seed)=obj.r2(yte,yp3); RMSE_all(3,seed)=sqrt(mean((yte-yp3).^2));
rot_var(3,seed)=obj.measure_rv_raw(Xte,mu3,s3,w3,b3);
params(3)=sum(cellfun(@numel,w3))+sum(cellfun(@numel,b3));
fprintf(' R2=%.4f rv=%.1e (%.1fs)\n', R2_all(3,seed), rot_var(3,seed), toc(t0));
% 4. Invariant MLP: features pooled over the rotation axis (dimension 3)
%    with mean, std, min and max, then z-scored with training statistics.
%    Its rotation variance is recorded as 0 by construction, not measured.
t0=tic; fprintf(' [4] Invariant MLP...');
fi_tr=[mean(Xtr,3),std(Xtr,0,3),min(Xtr,[],3),max(Xtr,[],3)];
fi_va=[mean(Xva,3),std(Xva,0,3),min(Xva,[],3),max(Xva,[],3)];
fi_te=[mean(Xte,3),std(Xte,0,3),min(Xte,[],3),max(Xte,[],3)];
[mu4,s4]=deal(mean(fi_tr),std(fi_tr)+1e-8);
fi_tr=(fi_tr-mu4)./s4; fi_va=(fi_va-mu4)./s4; fi_te=(fi_te-mu4)./s4;
[w4,b4]=obj.train_mlp(fi_tr,ytr,fi_va,yva,[size(fi_tr,2),64,32,1],300,0.003);
yp4=obj.predict_mlp(fi_te,w4,b4);
R2_all(4,seed)=obj.r2(yte,yp4); RMSE_all(4,seed)=sqrt(mean((yte-yp4).^2));
rot_var(4,seed)=0; params(4)=sum(cellfun(@numel,w4))+sum(cellfun(@numel,b4));
fprintf(' R2=%.4f (%.1fs)\n', R2_all(4,seed), toc(t0));
% 5. Augmented MLP: every rotation slice becomes an extra training row.
%    After permute+reshape the rows are grouped by rotation (all molecules
%    at rotation 1, then rotation 2, ...), which matches the repmat of the
%    targets. Testing uses rotation slice 1 only.
t0=tic; fprintf(' [5] Augmented MLP...');
nf=obj.n_feat_per_rot;
Xa5=reshape(permute(Xtr,[1,3,2]),[],nf); ya5=repmat(ytr,obj.n_rotations,1);
Xav=reshape(permute(Xva,[1,3,2]),[],nf); yav=repmat(yva,obj.n_rotations,1);
[mu5,s5]=deal(mean(Xa5),std(Xa5)+1e-8);
[w5,b5]=obj.train_mlp((Xa5-mu5)./s5,ya5,(Xav-mu5)./s5,yav,[nf,64,32,1],300,0.003);
yp5=obj.predict_mlp((squeeze(Xte(:,:,1))-mu5)./s5,w5,b5);
R2_all(5,seed)=obj.r2(yte,yp5); RMSE_all(5,seed)=sqrt(mean((yte-yp5).^2));
rot_var(5,seed)=obj.measure_rv_raw(Xte,mu5,s5,w5,b5);
params(5)=sum(cellfun(@numel,w5))+sum(cellfun(@numel,b5));
fprintf(' R2=%.4f rv=%.1e (%.1fs)\n', R2_all(5,seed), rot_var(5,seed), toc(t0));
end
% Summary table: one row per method, statistics taken across seeds.
fprintf('\n============================================================\n');
fprintf(' RESULTS (mean +/- std, %d seeds)\n', n_seeds);
fprintf('============================================================\n');
fprintf('%-22s %12s %12s %12s %8s\n','Method','Test R2','RMSE','Rot Var','Params');
fprintf('%s\n', repmat('-',1,70));
for m=1:nm
rv=rot_var(m,:);
% Method 4 shows a fixed label because its variance was never measured.
if m==4, rvs='~0 (exact)'; else, rvs=sprintf('%.2e',mean(rv)); end
fprintf('%-22s %5.3f+/-%5.3f %5.3f+/-%5.3f %12s %8d\n', ...
mn{m},mean(R2_all(m,:)),std(R2_all(m,:)),mean(RMSE_all(m,:)),std(RMSE_all(m,:)),rvs,params(m));
end
% Package the raw per-seed tables for the caller and keep a copy on the object.
results.method_names=mn; results.R2=R2_all; results.RMSE=RMSE_all;
results.rot_var=rot_var; results.params=params;
obj.results=results;
end
%% ROTATION VARIANCE ==============================================
function rv = measure_rotation_variance(obj,Xt,np,wr,wm,bm)
% MEASURE_ROTATION_VARIANCE  Prediction variance under cyclic shifts.
%
%   Xt     : (samples x features x rotations) tensor
%   np     : fourth argument forwarded to extractStarGFeatures
%   wr     : linear weights, used when no MLP is supplied
%   wm, bm : optional MLP weights/biases; when wm is non-empty the MLP is
%            used instead of wr
%
%   At most the first 50 samples are used. The rotation axis is circularly
%   shifted by 0..n_rotations-1 positions, features are re-extracted and
%   predictions made; the result is the per-sample variance across shifts,
%   averaged over samples.
    if nargin < 5
        wm = [];
    end
    if nargin < 6
        bm = [];
    end

    n_rot   = obj.n_rotations;
    n_probe = min(50, size(Xt,1));   % reduced for speed
    probe   = Xt(1:n_probe,:,:);
    use_mlp = ~isempty(wm);

    preds = zeros(n_probe, n_rot);
    for shift = 0:n_rot-1
        shifted = circshift(probe, shift, 3);
        feats = extractStarGFeatures(shifted, obj.G, obj.n_feat_per_rot, np);
        if use_mlp
            preds(:,shift+1) = obj.predict_mlp(feats, wm, bm);
        else
            preds(:,shift+1) = feats*wr;
        end
    end

    per_sample_var = var(preds, 0, 2);
    rv = mean(per_sample_var);
end
function rv = measure_rv_raw(obj,Xt,mu,sig,W,B)
% MEASURE_RV_RAW  Prediction variance of a raw-feature MLP across rotations.
%
%   Xt      : (samples x features x rotations) tensor
%   mu, sig : normalisation mean and scale applied to each rotation slice
%   W, B    : MLP weights and biases passed to predict_mlp
%
%   At most the first 50 samples are used. Each rotation slice is
%   normalised and scored; the result is the per-sample variance across
%   rotations, averaged over samples.
    n_rot   = obj.n_rotations;
    n_probe = min(50, size(Xt,1));
    preds   = zeros(n_probe, n_rot);

    for rot = 1:n_rot
        slice = squeeze(Xt(1:n_probe,:,rot));
        normalised = (slice - mu)./sig;
        preds(:,rot) = obj.predict_mlp(normalised, W, B);
    end

    per_sample_var = var(preds, 0, 2);
    rv = mean(per_sample_var);
end
%% ML UTILITIES ===================================================
function [w,bl] = ridge_cv(~,Xr,yr,Xv,yv)
% RIDGE_CV  Ridge regression with the penalty picked on a validation set.
%
%   Xr, yr : training design matrix and targets
%   Xv, yv : validation design matrix and targets
%
%   Returns the weights w refitted at the selected penalty bl. Candidates
%   are 1e-3 .. 1e3 in decades; the one with the lowest validation MSE wins
%   (0.01 if none improves on Inf). The first column of Xr is not penalised.
    candidates = [0.001,0.01,0.1,1,10,100,1000];

    n_col = size(Xr,2);
    penalty = eye(n_col);
    penalty(1,1) = 0;          % leave the first coefficient unregularised

    gram = Xr'*Xr;
    rhs  = Xr'*yr;

    bl = 0.01;
    lowest_mse = Inf;
    for k = 1:numel(candidates)
        lam = candidates(k);
        w_try = (gram + lam*penalty)\rhs;
        mse = mean((yv - Xv*w_try).^2);
        if mse < lowest_mse
            lowest_mse = mse;
            bl = lam;
        end
    end

    w = (gram + bl*penalty)\rhs;
end
function [W,B] = train_mlp(~,X,y,Xv,yv,layers,maxep,lr)
% TRAIN_MLP  Train a ReLU MLP regressor with Adam and early stopping.
%
%   X, y   : training inputs (rows = samples) and targets
%   Xv, yv : validation inputs and targets used for early stopping
%   layers : layer widths, e.g. [n_in, 64, 32, 1]
%   maxep  : maximum number of epochs
%   lr     : Adam learning rate
%
%   Returns cell arrays W (weights) and B (row-vector biases), one entry
%   per layer. Weights use He initialisation, biases start at zero, the
%   loss is mean squared error and mini-batches hold up to 256 samples.
%   Training stops after 25 consecutive epochs without a validation
%   improvement and the best-validation parameters are returned.
%
%   Adam's bias correction is driven by the number of parameter updates
%   performed so far (not the epoch index), so it stays correct when an
%   epoch contains several mini-batches.
%
%   When the validation set is empty, early stopping is disabled and the
%   parameters after the final epoch are returned.
%
%   Draws from the global random stream (randn, randperm).
    y = y(:); yv = yv(:);   % targets are single-output; force columns
    nl = length(layers)-1;
    W = cell(nl,1); B = cell(nl,1);
    for l = 1:nl
        fan_in = layers(l);
        W{l} = randn(fan_in,layers(l+1))*sqrt(2/fan_in);
        B{l} = zeros(1,layers(l+1));
    end

    % Adam first/second moment accumulators.
    mW = cellfun(@(w)zeros(size(w)),W,'UniformOutput',false); vW = mW;
    mB = cellfun(@(b)zeros(size(b)),B,'UniformOutput',false); vB = mB;
    b1 = 0.9; b2 = 0.999; ea = 1e-8;

    patience = 25; stalled = 0; best_val = Inf;
    Wb = W; Bb = B;
    has_val = ~isempty(Xv);

    nn = size(X,1); bs = min(256,nn);
    step = 0;   % number of Adam updates performed so far
    for ep = 1:maxep
        pm = randperm(nn);
        for s = 1:bs:nn
            bi = pm(s:min(s+bs-1,nn)); Xb = X(bi,:); yb = y(bi);

            % Forward pass, keeping every layer's activations.
            A = cell(nl+1,1); A{1} = Xb;
            for l = 1:nl
                Zl = A{l}*W{l}+B{l};
                if l<nl, A{l+1} = max(0,Zl); else, A{l+1} = Zl; end
            end

            dZ = (A{nl+1}-yb)/size(Xb,1);
            step = step+1;
            c1 = 1-b1^step; c2 = 1-b2^step;

            % Backward pass; the gradient for the layer below is formed
            % before this layer's weights are updated.
            for l = nl:-1:1
                gW = A{l}'*dZ; gB = sum(dZ,1);
                if l>1, dZ = (dZ*W{l}').*(A{l}>0); end
                mW{l} = b1*mW{l}+(1-b1)*gW; vW{l} = b2*vW{l}+(1-b2)*gW.^2;
                mB{l} = b1*mB{l}+(1-b1)*gB; vB{l} = b2*vB{l}+(1-b2)*gB.^2;
                W{l} = W{l}-lr*(mW{l}/c1)./(sqrt(vW{l}/c2)+ea);
                B{l} = B{l}-lr*(mB{l}/c1)./(sqrt(vB{l}/c2)+ea);
            end
        end

        if ~has_val, continue; end

        % Validation loss drives early stopping.
        yp = Xv;
        for l = 1:nl
            yp = yp*W{l}+B{l};
            if l<nl, yp = max(0,yp); end
        end
        vm = mean((yv-yp).^2);
        if vm < best_val
            best_val = vm; stalled = 0; Wb = W; Bb = B;
        else
            stalled = stalled+1;
            if stalled >= patience, break; end
        end
    end

    if has_val
        W = Wb; B = Bb;
    end
end
function yp = predict_mlp(~,X,W,B)
% PREDICT_MLP  Forward pass of a ReLU MLP with a linear output layer.
%
%   X : inputs, one sample per row
%   W : cell array of weight matrices, one per layer
%   B : cell array of row-vector biases, one per layer
%
%   Every layer but the last is followed by ReLU. With no layers the input
%   is returned unchanged.
    n_layers = length(W);
    act = X;
    for k = 1:n_layers-1
        act = max(0, act*W{k} + B{k});
    end
    if n_layers >= 1
        act = act*W{n_layers} + B{n_layers};
    end
    yp = act;
end
function v = r2(~,yt,yp)
% R2  Coefficient of determination of predictions yp against targets yt.
%
%   Computed as 1 - SS_res / SS_tot, with 1e-20 added to SS_tot so that a
%   constant target does not divide by zero.
    ss_res = sum((yt - yp).^2);
    ss_tot = sum((yt - mean(yt)).^2) + 1e-20;
    v = 1 - ss_res/ss_tot;
end
end
end