%% ========================================================================
%% run_qm9_nature.m
%% Master script: full Nature-paper pipeline
%%
%% LAPTOP-FRIENDLY DEFAULTS:
%% - 1000 molecules (increase for publication-quality results)
%% - Part 3 (learned algebra) skipped by default (experimental, slow)
%%
%% Usage:
%% >> run_qm9_nature % synthetic, 1000 mol
%% >> run_qm9_nature('qm9_dir', '/path/to/qm9/xyz/') % real QM9
%% >> run_qm9_nature('n_molecules', 5000, 'skip_part', []) % full run
%%
%% For real QM9 data:
%% 1. Download dsgdb9nsd.xyz.tar.bz2 from
%% https://figshare.com/collections/978904
%% 2. Extract .xyz files into a folder
%% 3. Point qm9_dir at that folder
%%
%% LH & Claude 2026
%% ========================================================================
function run_qm9_nature(varargin)
% RUN_QM9_NATURE  Drive the three-part QM9 pipeline and save its outputs.
%
%   run_qm9_nature(Name, Value, ...) understands these name-value options:
%     'qm9_dir'      char, default 'data/qm9'. Forwarded to QM9_experiment.
%     'n_molecules'  numeric, default 1000. Forwarded to load_data (and, in
%                    Part 1 only, to the QM9_experiment constructor).
%     'n_rotations'  numeric, default 12. Second constructor argument of
%                    QM9_experiment and 'max_order' for symmetry discovery.
%     'target_col'   numeric, default 8. Column of properties_mat used as
%                    the regression target.
%     'n_seeds'      numeric, default 3. Forwarded to run_comparison.
%     'results_dir'  char, default 'results/qm9_nature'. Created if missing;
%                    all .mat and .png outputs are written here.
%     'skip_part'    numeric vector, default 3. Parts (1, 2, 3) to skip.
%
%   Files written (only for the parts that run):
%     comparison_results.mat, comparison_figure.png   (Part 1)
%     symmetry_discovery.mat, symmetry_landscape.png  (Part 2)
%     learned_algebra.mat                             (Part 3)
%
%   Plotting is best-effort: a plotting failure is reported as a warning
%   and the pipeline continues.

    % Make the sibling 'core' and 'experiments' folders (one level above
    % this file's folder) visible on the MATLAB path.
    thisDir = fileparts(mfilename('fullpath'));
    rootDir = fileparts(thisDir);
    addpath(fullfile(rootDir, 'core'));
    addpath(fullfile(rootDir, 'experiments'));

    p = inputParser;
    addParameter(p, 'qm9_dir', 'data/qm9', @ischar);
    addParameter(p, 'n_molecules', 1000, @isnumeric);   % laptop-friendly
    addParameter(p, 'n_rotations', 12, @isnumeric);
    addParameter(p, 'target_col', 8, @isnumeric);
    addParameter(p, 'n_seeds', 3, @isnumeric);
    addParameter(p, 'results_dir', 'results/qm9_nature', @ischar);
    addParameter(p, 'skip_part', [3], @isnumeric);       % skip Part 3 by default
    parse(p, varargin{:});
    opts = p.Results;

    if ~exist(opts.results_dir, 'dir'), mkdir(opts.results_dir); end

    fprintf('\n================================================================\n');
    fprintf(' star_G Tensor Algebra: Nature Paper Pipeline\n');
    fprintf('================================================================\n');
    fprintf(' QM9 data dir : %s\n', opts.qm9_dir);
    fprintf(' n_molecules : %d\n', opts.n_molecules);
    fprintf(' n_rotations : %d (Z_%d)\n', opts.n_rotations, opts.n_rotations);
    fprintf(' target column : %d\n', opts.target_col);
    fprintf(' seeds : %d\n', opts.n_seeds);
    fprintf(' skip parts : %s\n', mat2str(opts.skip_part));
    fprintf(' results dir : %s\n', opts.results_dir);
    fprintf('================================================================\n\n');

    total_t0 = tic;

    % The experiment and the symmetry_discovery object are shared between
    % parts; these flags record whether an earlier part already built them.
    % (The experiment variable is deliberately NOT called 'exp', which
    % would shadow MATLAB's built-in exponential function.)
    have_exp = false;
    have_sd  = false;

    %% PART 1: Comparison Study ==========================================
    if ~ismember(1, opts.skip_part)
        fprintf('########################################################\n');
        fprintf('# PART 1: Comparison Study #\n');
        fprintf('########################################################\n');
        t1 = tic;

        % Part 1 additionally hands n_molecules to the constructor.
        qexp = local_prepare_experiment(opts, 'n_molecules', opts.n_molecules);
        have_exp = true;

        results = qexp.run_comparison(opts.target_col, 'n_seeds', opts.n_seeds);
        save(fullfile(opts.results_dir, 'comparison_results.mat'), 'results', 'opts');

        try
            fig1 = figure('Position', [100,100,1000,400]);
            method_labels = strrep(results.method_names, '_', ' ');

            % Left panel: mean test R^2 per method with std error bars.
            subplot(1,2,1);
            R2m = mean(results.R2, 2);
            R2s = std(results.R2, 0, 2);
            n_methods = numel(R2m);   % derived from the data, not assumed
            bar(R2m); hold on;
            errorbar(1:n_methods, R2m, R2s, 'k.', 'LineWidth', 1.5);
            % Pin one tick per bar so every label lines up with its bar.
            set(gca, 'XTick', 1:n_methods, 'XTickLabel', method_labels, ...
                'XTickLabelRotation', 30);
            ylabel('Test R^2'); title('Predictive Performance');
            ylim([min(0, min(R2m) - 0.1), 1.05]); grid on;

            % Right panel: log10 of the mean rotation variance per method.
            subplot(1,2,2);
            rv = mean(results.rot_var, 2);
            rv(rv == 0) = 1e-32;      % floor exact zeros so log10 stays finite
            bar(log10(rv + 1e-33));
            set(gca, 'XTick', 1:numel(rv), 'XTickLabel', method_labels, ...
                'XTickLabelRotation', 30);
            ylabel('log_{10}(Rot Var)'); title('Invariance'); grid on;

            saveas(fig1, fullfile(opts.results_dir, 'comparison_figure.png'));
            fprintf('Figure saved.\n');
        catch plotErr
            % Plotting is optional (e.g. no display); report, don't abort.
            warning('run_qm9_nature:plotFailed', ...
                'Part 1 figure was not produced: %s', plotErr.message);
        end

        fprintf('Part 1 done in %.1fs.\n\n', toc(t1));
    end

    %% PART 2: Symmetry Discovery ========================================
    if ~ismember(2, opts.skip_part)
        fprintf('########################################################\n');
        fprintf('# PART 2: Symmetry Discovery #\n');
        fprintf('########################################################\n');
        t2 = tic;

        % Reuse Part 1's experiment when it exists and has features.
        if ~have_exp || isempty(qexp.X_tensor)
            qexp = local_prepare_experiment(opts);
            have_exp = true;
        end

        sd = symmetry_discovery(qexp.X_tensor, qexp.properties_mat(:, opts.target_col));
        have_sd = true;
        [~, report] = sd.discover('max_order', opts.n_rotations, ...
            'test_dihedral', true, 'test_klein4', true, ...
            'test_quaternion', false, 'supervised_weight', 0.4);
        save(fullfile(opts.results_dir, 'symmetry_discovery.mat'), 'report');

        % Candidate groups ranked by combined score, best first; the same
        % ordering drives both the figure and the text table.
        [sorted_scores, rank_idx] = sort([report.combined_score], 'descend');

        try
            fig2 = figure('Position', [100,100,800,400]);
            all_names = {report.name};
            sorted_names = all_names(rank_idx);
            bar(sorted_scores, 'FaceColor', [0.3,0.6,0.9]); hold on;
            % Re-draw the top-ranked bar in a highlight colour.
            bar(1, sorted_scores(1), 'FaceColor', [0.9,0.3,0.2]);
            set(gca, 'XTick', 1:numel(sorted_scores), ...
                'XTickLabel', sorted_names, 'XTickLabelRotation', 45);
            ylabel('Score'); title(sprintf('Best = %s', sd.best_group_name)); grid on;
            saveas(fig2, fullfile(opts.results_dir, 'symmetry_landscape.png'));
            fprintf('Figure saved.\n');
        catch plotErr
            warning('run_qm9_nature:plotFailed', ...
                'Part 2 figure was not produced: %s', plotErr.message);
        end

        fprintf('\nDetailed Report:\n');
        fprintf('%-10s %6s %10s %10s %8s %10s\n','Group','Order','Compress','Sparsity','Pred R2','Score');
        fprintf('%s\n', repmat('-',1,60));
        for i = rank_idx
            % Show 'N/A' when no prediction R^2 is available (NaN).
            r2_text = 'N/A';
            if ~isnan(report(i).prediction_r2)
                r2_text = sprintf('%+.3f', report(i).prediction_r2);
            end
            fprintf('%-10s %6d %10.4f %10.4f %8s %10.4f\n', ...
                report(i).name, report(i).order, report(i).svd_compress, ...
                report(i).fourier_sparsity, r2_text, report(i).combined_score);
        end

        fprintf('Part 2 done in %.1fs.\n\n', toc(t2));
    end

    %% PART 3: Learned Algebra ===========================================
    if ~ismember(3, opts.skip_part)
        fprintf('########################################################\n');
        fprintf('# PART 3: Learned Algebra (Experimental) #\n');
        fprintf('########################################################\n');
        t3 = tic;

        % Reuse Part 2's symmetry_discovery object if it ran; otherwise
        % build one (and the experiment underneath it) from scratch.
        if ~have_sd
            if ~have_exp || isempty(qexp.X_tensor)
                qexp = local_prepare_experiment(opts);
            end
            sd = symmetry_discovery(qexp.X_tensor, qexp.properties_mat(:, opts.target_col));
        end

        [F_learned, C_learned, err] = sd.learn_algebra('max_iter', 100, 'lambda', 0.02);

        % Compare |C_learned| against a reference n-by-n-by-n tensor that is
        % 1 on the (k,k,k) diagonal and 0 elsewhere.
        n = size(C_learned, 1);
        C_cyclic = zeros(n, n, n);
        for k = 1:n
            C_cyclic(k,k,k) = 1;
        end
        fprintf(' Z_%d distance: %.4f\n', n, norm(abs(C_learned(:)) - abs(C_cyclic(:))));

        save(fullfile(opts.results_dir, 'learned_algebra.mat'), 'F_learned', 'C_learned', 'err');
        fprintf('Part 3 done in %.1fs.\n\n', toc(t3));
    end

    fprintf('================================================================\n');
    fprintf(' Total time: %.1fs. Results in: %s\n', toc(total_t0), opts.results_dir);
    fprintf('================================================================\n');
end

function qexp = local_prepare_experiment(opts, varargin)
% LOCAL_PREPARE_EXPERIMENT  Construct a QM9_experiment, load data, build features.
%   Any extra arguments are forwarded verbatim to the QM9_experiment
%   constructor after (qm9_dir, n_rotations).
    N_FEAT = 48;   % feature count requested from compute_rotated_features
    qexp = QM9_experiment(opts.qm9_dir, opts.n_rotations, varargin{:});
    qexp = qexp.load_data(opts.n_molecules);
    qexp = qexp.compute_rotated_features('n_feat', N_FEAT);
end