% tensor-group-sym/experiments/run_qm9_nature.m
%% ========================================================================
%% run_qm9_nature.m
%% Master script: full Nature-paper pipeline
%%
%% LAPTOP-FRIENDLY DEFAULTS:
%%   - 1000 molecules (increase for publication-quality results)
%%   - Part 3 (learned algebra) skipped by default (experimental, slow)
%%
%% Usage:
%%   >> run_qm9_nature                                        % synthetic, 1000 mol
%%   >> run_qm9_nature('qm9_dir', '/path/to/qm9/xyz/')       % real QM9
%%   >> run_qm9_nature('n_molecules', 5000, 'skip_part', [])  % full run
%%
%% For real QM9 data:
%%   1. Download dsgdb9nsd.xyz.tar.bz2 from
%%      https://figshare.com/collections/978904
%%   2. Extract .xyz files into a folder
%%   3. Point qm9_dir at that folder
%%
%% LH & Claude 2026
%% ========================================================================

function run_qm9_nature(varargin)
%RUN_QM9_NATURE Run the QM9 star_G tensor-algebra pipeline (Parts 1-3).
%   RUN_QM9_NATURE() runs with the defaults below.
%   RUN_QM9_NATURE(Name, Value, ...) accepts:
%     'qm9_dir'     - folder passed to QM9_experiment (default 'data/qm9');
%                     char or string scalar
%     'n_molecules' - number of molecules to load, positive scalar
%                     (default 1000)
%     'n_rotations' - order n of the Z_n rotation group; also used as
%                     'max_order' for symmetry discovery (default 12)
%     'target_col'  - target property column; indexes properties_mat in
%                     Parts 2/3 and is passed to run_comparison (default 8)
%     'n_seeds'     - number of seeds for the comparison study (default 3)
%     'results_dir' - output folder, created if missing
%                     (default 'results/qm9_nature'); char or string scalar
%     'skip_part'   - numeric vector of part numbers to skip (default 3)
%
%   Files written to results_dir:
%     Part 1: comparison_results.mat, comparison_figure.png
%     Part 2: symmetry_discovery.mat, symmetry_landscape.png
%     Part 3: learned_algebra.mat
%   Figures are best-effort: a plotting/saving failure issues a warning
%   (id 'run_qm9_nature:figureFailed') instead of stopping the pipeline.
%
%   Side effect: <root>/core and <root>/experiments are added to the MATLAB
%   path (where <root> is the parent of this file's folder) and are left
%   there after the function returns.

    thisDir = fileparts(mfilename('fullpath'));
    rootDir = fileparts(thisDir);
    addpath(fullfile(rootDir, 'core'));
    addpath(fullfile(rootDir, 'experiments'));

    % Counts and column indices must be positive integer scalars;
    % n_molecules only needs to be positive (Inf is allowed).
    isPosInt   = @(x) isnumeric(x) && isscalar(x) && x >= 1 && x == round(x);
    isPosCount = @(x) isnumeric(x) && isscalar(x) && x > 0;
    isText     = @(x) ischar(x) || (isstring(x) && isscalar(x));

    p = inputParser;
    addParameter(p, 'qm9_dir', 'data/qm9', isText);
    addParameter(p, 'n_molecules', 1000, isPosCount);   % laptop-friendly
    addParameter(p, 'n_rotations', 12, isPosInt);
    addParameter(p, 'target_col', 8, isPosInt);
    addParameter(p, 'n_seeds', 3, isPosInt);
    addParameter(p, 'results_dir', 'results/qm9_nature', isText);
    addParameter(p, 'skip_part', 3, @isnumeric);        % skip Part 3 by default
    parse(p, varargin{:});
    opts = p.Results;
    % Normalize string scalars so downstream code only ever sees char paths.
    opts.qm9_dir     = char(opts.qm9_dir);
    opts.results_dir = char(opts.results_dir);

    if ~exist(opts.results_dir, 'dir'), mkdir(opts.results_dir); end

    fprintf('\n================================================================\n');
    fprintf('  star_G Tensor Algebra: Nature Paper Pipeline\n');
    fprintf('================================================================\n');
    fprintf('  QM9 data dir   : %s\n', opts.qm9_dir);
    fprintf('  n_molecules    : %d\n', opts.n_molecules);
    fprintf('  n_rotations    : %d (Z_%d)\n', opts.n_rotations, opts.n_rotations);
    fprintf('  target column  : %d\n', opts.target_col);
    fprintf('  seeds          : %d\n', opts.n_seeds);
    fprintf('  skip parts     : %s\n', mat2str(opts.skip_part));
    fprintf('  results dir    : %s\n', opts.results_dir);
    fprintf('================================================================\n\n');

    total_t0 = tic;

    % Shared state, created lazily so any subset of parts can run.
    % (Named qm9 rather than exp to avoid shadowing the built-in exp.)
    qm9 = [];   % QM9_experiment with rotated features
    sd  = [];   % symmetry_discovery built on qm9's features/target

    %% PART 1: Comparison Study ==========================================
    if ~ismember(1, opts.skip_part)
        fprintf('########################################################\n');
        fprintf('#  PART 1: Comparison Study                            #\n');
        fprintf('########################################################\n');
        t1 = tic;

        qm9 = load_experiment(qm9, opts);
        results = qm9.run_comparison(opts.target_col, 'n_seeds', opts.n_seeds);

        save(fullfile(opts.results_dir, 'comparison_results.mat'), 'results', 'opts');

        try_figure(@() plot_comparison(results, opts.results_dir));
        fprintf('Part 1 done in %.1fs.\n\n', toc(t1));
    end

    %% PART 2: Symmetry Discovery ========================================
    if ~ismember(2, opts.skip_part)
        fprintf('########################################################\n');
        fprintf('#  PART 2: Symmetry Discovery                          #\n');
        fprintf('########################################################\n');
        t2 = tic;

        qm9 = load_experiment(qm9, opts);   % reuses Part 1's data if present

        sd = symmetry_discovery(qm9.X_tensor, qm9.properties_mat(:, opts.target_col));
        [~, report] = sd.discover('max_order', opts.n_rotations, ...
            'test_dihedral', true, 'test_klein4', true, ...
            'test_quaternion', false, 'supervised_weight', 0.4);

        save(fullfile(opts.results_dir, 'symmetry_discovery.mat'), 'report');

        % sd.best_group_name is read when the callback runs, i.e. inside the
        % best-effort try, exactly as before.
        try_figure(@() plot_symmetry_landscape(report, sd.best_group_name, ...
            opts.results_dir));

        print_discovery_report(report);
        fprintf('Part 2 done in %.1fs.\n\n', toc(t2));
    end

    %% PART 3: Learned Algebra ===========================================
    if ~ismember(3, opts.skip_part)
        fprintf('########################################################\n');
        fprintf('#  PART 3: Learned Algebra (Experimental)              #\n');
        fprintf('########################################################\n');
        t3 = tic;

        if isempty(sd)   % Part 2 was skipped
            qm9 = load_experiment(qm9, opts);
            sd = symmetry_discovery(qm9.X_tensor, qm9.properties_mat(:, opts.target_col));
        end

        [F_learned, C_learned, err] = sd.learn_algebra('max_iter', 100, 'lambda', 0.02);

        % Reference tensor the output labels "Z_n": ones on the diagonal
        % C(k,k,k). Assumes C_learned is n-by-n-by-n (TODO confirm); abs()
        % compares magnitudes only, ignoring signs/phases.
        n = size(C_learned, 1);
        C_cyclic = zeros(n, n, n);
        for k = 1:n
            C_cyclic(k, k, k) = 1;
        end
        fprintf('  Z_%d distance: %.4f\n', n, norm(abs(C_learned(:))-abs(C_cyclic(:))));
        save(fullfile(opts.results_dir, 'learned_algebra.mat'), 'F_learned', 'C_learned', 'err');
        fprintf('Part 3 done in %.1fs.\n\n', toc(t3));
    end

    fprintf('================================================================\n');
    fprintf('  Total time: %.1fs. Results in: %s\n', toc(total_t0), opts.results_dir);
    fprintf('================================================================\n');
end

function qm9 = load_experiment(qm9, opts)
%LOAD_EXPERIMENT Return a QM9_experiment with rotated features computed.
%   If QM9 is non-empty and its X_tensor is populated, it is returned as-is.
%   Otherwise a new experiment is constructed from opts.qm9_dir and
%   opts.n_rotations, opts.n_molecules molecules are loaded, and rotated
%   features are computed with n_feat = 48.
    if ~isempty(qm9) && ~isempty(qm9.X_tensor)
        return;
    end
    qm9 = QM9_experiment(opts.qm9_dir, opts.n_rotations, ...
        'n_molecules', opts.n_molecules);
    qm9 = qm9.load_data(opts.n_molecules);
    qm9 = qm9.compute_rotated_features('n_feat', 48);
end

function try_figure(plotFcn)
%TRY_FIGURE Run a best-effort plotting callback.
%   Prints 'Figure saved.' on success. On any error, issues a warning with
%   the error message instead of rethrowing.
    try
        plotFcn();
        fprintf('Figure saved.\n');
    catch figErr
        warning('run_qm9_nature:figureFailed', ...
            'Figure could not be created or saved: %s', figErr.message);
    end
end

function plot_comparison(results, resultsDir)
%PLOT_COMPARISON Bar charts of mean test R^2 (with std error bars) and
%   log10 rotational variance per method; saves comparison_figure.png.
%   results.R2 and results.rot_var are averaged along dimension 2
%   (presumably one row per method, one column per seed - TODO confirm).
    labels = strrep(results.method_names, '_', ' ');
    fig = figure('Position', [100,100,1000,400]);

    subplot(1,2,1);
    R2m = mean(results.R2, 2);
    R2s = std(results.R2, 0, 2);
    bar(R2m); hold on;
    % One error bar per method; the method count is not fixed.
    errorbar(1:numel(R2m), R2m, R2s, 'k.', 'LineWidth', 1.5);
    set(gca, 'XTickLabel', labels, 'XTickLabelRotation', 30);
    ylabel('Test R^2'); title('Predictive Performance');
    ylim([min(0, min(R2m)-0.1), 1.05]); grid on;

    subplot(1,2,2);
    rv = mean(results.rot_var, 2);
    rv(rv == 0) = 1e-32;   % give exact zeros a finite value on the log scale
    bar(log10(rv + 1e-33));
    set(gca, 'XTickLabel', labels, 'XTickLabelRotation', 30);
    ylabel('log_{10}(Rot Var)'); title('Invariance'); grid on;

    saveas(fig, fullfile(resultsDir, 'comparison_figure.png'));
end

function plot_symmetry_landscape(report, bestName, resultsDir)
%PLOT_SYMMETRY_LANDSCAPE Bar chart of report.combined_score sorted in
%   descending order, with the top-ranked group drawn in red; saves
%   symmetry_landscape.png.
    fig = figure('Position', [100,100,800,400]);
    names = {report.name};
    [scoresSorted, order] = sort([report.combined_score], 'descend');
    bar(scoresSorted, 'FaceColor', [0.3,0.6,0.9]); hold on;
    bar(1, scoresSorted(1), 'FaceColor', [0.9,0.3,0.2]);   % highlight rank 1
    set(gca, 'XTickLabel', names(order), 'XTickLabelRotation', 45);
    ylabel('Score'); title(sprintf('Best = %s', bestName)); grid on;
    saveas(fig, fullfile(resultsDir, 'symmetry_landscape.png'));
end

function print_discovery_report(report)
%PRINT_DISCOVERY_REPORT Print one table row per candidate group, sorted by
%   combined_score (descending). A NaN prediction_r2 is shown as 'N/A'.
    fprintf('\nDetailed Report:\n');
    fprintf('%-10s %6s %10s %10s %8s %10s\n','Group','Order','Compress','Sparsity','Pred R2','Score');
    fprintf('%s\n', repmat('-',1,60));
    [~, order] = sort([report.combined_score], 'descend');
    for i = order
        r2s = 'N/A';
        if ~isnan(report(i).prediction_r2)
            r2s = sprintf('%+.3f', report(i).prediction_r2);
        end
        fprintf('%-10s %6d %10.4f %10.4f %8s %10.4f\n', ...
            report(i).name, report(i).order, report(i).svd_compress, ...
            report(i).fourier_sparsity, r2s, report(i).combined_score);
    end
end