%% tensor-group-sym / experiments / discover_symmetry_from_data.m
%% ========================================================================
%% discover_symmetry_from_data.m
%% Discover property-dependent irrep structure via the octahedral group
%%
%% KEY IDEA: Replace the cyclic group Z_24 (single rotation axis) with
%% the chiral octahedral group O (24 elements, subgroup of SO(3)).
%% The octahedral group has 5 irreps (A1, A2, E, T1, T2) with dimensions
%% 1,1,2,3,3; each first appears in the angular-momentum channel l=0,3,2,1,2.
%%
%% WIGNER-ECKART PREDICTION:
%%   Scalar properties (gap, HOMO, LUMO, ZPVE):
%%     -> Depend only on A1 (trivial irrep, l=0)
%%   Dipole moment components mu_x, mu_y, mu_z (a rank-1 vector):
%%     -> Need T1 (the l=1 irrep) to capture directional information;
%%        the magnitude |mu| itself is rotation-invariant (l=0).
%%   Polarizability alpha (trace of a rank-2 tensor):
%%     -> The trace is invariant (A1, l=0); the anisotropic l=2 part of the
%%        full tensor would split into E + T2 under O.
%%
%% If the data shows T1 active for dipole but not for energies,
%% we've empirically recovered selection rules from geometry alone.
%%
%% Usage:
%%   >> discover_symmetry_from_data('qm9_dir', './data/xyz/')
%%
%% LH & Claude 2026
%% ========================================================================

function results = discover_symmetry_from_data(varargin)
    % DISCOVER_SYMMETRY_FROM_DATA  Decompose structure-property relations
    % into irreps of the chiral octahedral group O.
    %
    % Each molecule is rotated by every element of O, a fixed set of features
    % is evaluated per rotation, and the per-irrep Fourier power of those
    % signals is used (one irrep at a time, then all irreps together) as
    % ridge-regression input for several QM9 properties and for dipole
    % components computed from Mulliken charges.
    %
    % Name/value options:
    %   'qm9_dir'     - folder with QM9 .xyz files (default './data/xyz/')
    %   'n_molecules' - maximum number of files to read (default 2000)
    %   'results_dir' - output folder (default 'results/symmetry_discovery';
    %                   created if missing)
    %
    % Returns:
    %   results - struct with fields
    %     P                - property table (name, props column, rank, ...)
    %     irrep_info       - irreps as returned by build_octahedral_group
    %     r2_per_irrep     - np x n_irreps test R^2 using one irrep's features
    %     r2_full          - np x 1 test R^2 using all irreps' features
    %     irrep_importance - max(0, r2_per_irrep) normalised per property
    %     irrep_power      - n_mol x n_irreps total Fourier power
    %   The struct is also saved to <results_dir>/wigner_eckart_results.mat.
    %
    % Side effects: sets groot figure defaults, adds core/ and experiments/
    % to the path, seeds the global RNG with rng(42), prints a report and
    % writes irrep_heatmap.{pdf,png} and irrep_bars.{pdf,png}.
    %
    % Errors if no molecule could be loaded from qm9_dir.

    set(groot,'defaultFigureColor','w','defaultAxesColor','w',...
        'defaultAxesXColor','k','defaultAxesYColor','k','defaultTextColor','k');

    thisDir = fileparts(mfilename('fullpath'));
    rootDir = fileparts(thisDir);
    addpath(fullfile(rootDir, 'core'));
    addpath(fullfile(rootDir, 'experiments'));

    pa = inputParser;
    addParameter(pa, 'qm9_dir', './data/xyz/', @ischar);
    addParameter(pa, 'n_molecules', 2000, @isnumeric);
    addParameter(pa, 'results_dir', 'results/symmetry_discovery', @ischar);
    parse(pa, varargin{:});
    opts = pa.Results;
    if ~exist(opts.results_dir,'dir'), mkdir(opts.results_dir); end

    fprintf('\n================================================================\n');
    fprintf('  Wigner-Eckart Discovery via Octahedral Group\n');
    fprintf('  %d molecules, 24 orientations (chiral octahedral O)\n', opts.n_molecules);
    fprintf('================================================================\n\n');

    %% 1. Build the octahedral group
    % The group order is taken from the generated rotations (not hard-coded)
    % so every array below agrees with R_mats. The Cayley table is not needed.
    [R_mats, ~, irrep_info] = build_octahedral_group();
    ng = size(R_mats, 3);

    fprintf('Octahedral group O: %d elements, 5 irreps\n', ng);
    fprintf('  Irrep    Dim  l-channel  Description\n');
    for i = 1:length(irrep_info)
        fprintf('  %-6s   %d    l=%d        %s\n', ...
            irrep_info(i).name, irrep_info(i).dim, ...
            irrep_info(i).l, irrep_info(i).desc);
    end

    % Build StarGAlgebra with the proper octahedral irreps. The previous
    % 'table' constructor used fft(eye(24)) for F, which is mathematically
    % wrong for the non-abelian octahedral group; the dedicated
    % 'octahedral' constructor builds F from the actual A1, A2, E, T1, T2
    % irrep matrices.
    % NOTE(review): G_oct is not used below (section 4 works directly from
    % irrep_info); the construction is kept unchanged -- remove once
    % confirmed unnecessary.
    G_oct = StarGAlgebra('octahedral'); %#ok<NASGU>

    %% 2. Load QM9 data
    % Use a simple QM9 loader (bypass QM9_experiment since we need custom rotations)
    [coords, Z_all, charges_all, props, n_mol] = load_qm9_simple(opts.qm9_dir, opts.n_molecules);
    if n_mol == 0
        % Everything below indexes molecule 1 and splits the data set.
        error('discover_symmetry_from_data:noData', ...
            'No valid .xyz molecules could be loaded from %s', opts.qm9_dir);
    end

    %% 2b. Compute dipole vectors from Mulliken charges
    % mu_vec = sum_i q_i * r_i (in the original molecular orientation)
    mu_vec = zeros(n_mol, 3);
    for i = 1:n_mol
        mu_vec(i,:) = charges_all{i}' * coords{i};  % 1x3 dipole vector
    end
    mu_norm = vecnorm(mu_vec, 2, 2);
    fprintf('Dipole vectors computed from Mulliken charges.\n');
    fprintf('  |mu| range: [%.4f, %.4f], mean: %.4f\n', ...
        min(mu_norm), max(mu_norm), mean(mu_norm));
    fprintf('  mu_x range: [%.4f, %.4f]\n', min(mu_vec(:,1)), max(mu_vec(:,1)));
    fprintf('  mu_y range: [%.4f, %.4f]\n', min(mu_vec(:,2)), max(mu_vec(:,2)));
    fprintf('  mu_z range: [%.4f, %.4f]\n', min(mu_vec(:,3)), max(mu_vec(:,3)));

    %% 3. Compute features under octahedral rotations
    fprintf('\nComputing features under 24 octahedral rotations...\n');
    n_feat_target = 36;
    t0 = tic;

    % First molecule to determine feature count
    pos0 = coords{1} - mean(coords{1},1);
    feat0 = molecular_features_3d(pos0, Z_all{1}, R_mats, n_feat_target);
    n_feat = size(feat0, 1);

    X = zeros(n_mol, n_feat, ng);
    for mol = 1:n_mol
        pos = coords{mol} - mean(coords{mol},1);
        X(mol,:,:) = molecular_features_3d(pos, Z_all{mol}, R_mats, n_feat_target);
        if mod(mol,500)==0, fprintf('  %d/%d (%.1fs)\n', mol, n_mol, toc(t0)); end
    end
    fprintf('Feature tensor: %d x %d x %d (%.1fs)\n', size(X), toc(t0));

    %% 4. Per-irrep Fourier power
    % For an irrep rho of dimension d, the Fourier coefficient of row j is
    %   Xhat_rho(j) = sqrt(d/ng) * sum_g X(j,g) * rho(g)        (d x d)
    % and its power is ||Xhat_rho(j)||_F^2. Provided irrep_info holds a
    % complete set of orthogonal irreps, the powers over all irreps sum to
    % sum_g X(j,g)^2 (Plancherel).
    %
    % NOTE(review): with orthogonal rho these powers are unchanged when the
    % input molecule is pre-rotated by any element of O (the signal is only
    % re-indexed over g). Every regression below therefore sees O-invariant
    % inputs, so frame-dependent targets such as mu_x can only be predicted
    % to the extent the stored QM9 frame is itself informative.
    fprintf('\nComputing per-irrep Fourier power...\n');

    irrep_names = {irrep_info.name};
    irrep_dims = [irrep_info.dim];
    n_irreps = length(irrep_info);

    % Stack vec(rho(g)) as the columns of a (d^2 x ng) matrix per irrep, so
    % the transform of all feature rows is a single matrix product.
    rho_stack = cell(1, n_irreps);
    for ri = 1:n_irreps
        d = irrep_dims(ri);
        M = zeros(d*d, ng);
        for g = 1:ng
            M(:, g) = reshape(irrep_info(ri).matrices{g}, [], 1);
        end
        rho_stack{ri} = M;
    end

    irrep_power = zeros(n_mol, n_irreps);              % total power per irrep
    irrep_row_power = zeros(n_mol, n_feat, n_irreps);  % per-row per-irrep
    for s = 1:n_mol
        Xi = reshape(X(s,:,:), n_feat, ng);
        for ri = 1:n_irreps
            % Row j of Xhat is vec(sum_g X(j,g) * rho(g))'.
            Xhat = Xi * rho_stack{ri}.';
            p = (irrep_dims(ri) / ng) * sum(Xhat.^2, 2);
            irrep_row_power(s, :, ri) = p.';
            irrep_power(s, ri) = sum(p);
        end
        if mod(s,500)==0, fprintf('  %d/%d\n', s, n_mol); end
    end

    %% 5. Properties
    % Scalar properties from QM9 + dipole vector components from charges.
    % col > 0 indexes props; col == 0 means custom_y is used instead.
    % NOTE(review): these column indices assume props(:,1) holds the QM9
    % molecule index (so col 5 = mu, 6 = alpha, 7 = HOMO, 8 = LUMO, 9 = gap,
    % 11 = ZPVE) -- keep in sync with load_qm9_simple's column layout.
    P = struct('col',{},'name',{},'type',{},'rank',{},'color',{},'custom_y',{});
    P(end+1) = struct('col',9, 'name','HOMO-LUMO gap',  'type','scalar', 'rank',0, 'color',[0.17,0.50,0.72], 'custom_y',[]);
    P(end+1) = struct('col',7, 'name','HOMO energy',    'type','scalar', 'rank',0, 'color',[0.30,0.65,0.85], 'custom_y',[]);
    P(end+1) = struct('col',8, 'name','LUMO energy',    'type','scalar', 'rank',0, 'color',[0.50,0.78,0.90], 'custom_y',[]);
    P(end+1) = struct('col',5, 'name','|Dipole| (mag)', 'type','scalar', 'rank',0, 'color',[0.90,0.60,0.40], 'custom_y',[]);
    P(end+1) = struct('col',0, 'name','Dipole mu_x',    'type','vector', 'rank',1, 'color',[0.85,0.15,0.10], 'custom_y',mu_vec(:,1));
    P(end+1) = struct('col',0, 'name','Dipole mu_y',    'type','vector', 'rank',1, 'color',[0.80,0.25,0.15], 'custom_y',mu_vec(:,2));
    P(end+1) = struct('col',0, 'name','Dipole mu_z',    'type','vector', 'rank',1, 'color',[0.75,0.35,0.20], 'custom_y',mu_vec(:,3));
    P(end+1) = struct('col',6, 'name','Polarizability', 'type','tensor', 'rank',2, 'color',[0.55,0.20,0.60], 'custom_y',[]);
    P(end+1) = struct('col',11,'name','ZPVE',           'type','scalar', 'rank',0, 'color',[0.40,0.70,0.45], 'custom_y',[]);
    np = length(P);

    %% 6. Per-irrep importance for each property
    fprintf('\n========================================\n');
    fprintf('  Per-Irrep Importance Analysis\n');
    fprintf('========================================\n');

    % 60/20/20 train/validation/test split with a fixed seed.
    rng(42); idx = randperm(n_mol);
    ntr = round(0.6*n_mol); nva = round(0.2*n_mol);
    tri = idx(1:ntr); vai = idx(ntr+1:ntr+nva); tei = idx(ntr+nva+1:end);

    irrep_importance = zeros(np, n_irreps);
    r2_per_irrep = zeros(np, n_irreps);
    r2_full = zeros(np, 1);
    r2_A1_only = zeros(np, 1); %#ok<NASGU>  kept for parity with earlier versions

    feat_all = reshape(irrep_row_power, n_mol, []);  % all irreps side by side

    % "ip" rather than "pi" so the builtin constant is not shadowed.
    for ip = 1:np
        if P(ip).col > 0
            y = props(:, P(ip).col);
        else
            y = P(ip).custom_y;
        end
        % Standardise the target with training statistics only.
        ytr = y(tri); yva = y(vai); yte = y(tei);
        y_mu = mean(ytr); y_sig = std(ytr); if y_sig<1e-10, y_sig=1; end
        ytr_n = (ytr-y_mu)/y_sig; yva_n = (yva-y_mu)/y_sig; yte_n = (yte-y_mu)/y_sig;

        fprintf('\n , %s (rank %d) , \n', P(ip).name, P(ip).rank);

        % One irrep at a time: per-row power at that irrep as the features.
        for ri = 1:n_irreps
            r2_per_irrep(ip, ri) = ridge_test_r2(irrep_row_power(:,:,ri), ...
                tri, vai, tei, ytr_n, yva_n, yte_n);
            fprintf('    %s (dim %d, l=%d): R2=%.4f\n', ...
                irrep_names{ri}, irrep_dims(ri), irrep_info(ri).l, r2_per_irrep(ip,ri));
        end

        r2_A1_only(ip) = r2_per_irrep(ip, 1);

        % All irreps combined
        r2_full(ip) = ridge_test_r2(feat_all, tri, vai, tei, ytr_n, yva_n, yte_n);

        % Importance = per-irrep test R^2 clipped at 0 and scaled so the
        % best irrep for this property is 1 (not an increment over A1).
        irrep_importance(ip,:) = max(0, r2_per_irrep(ip,:));
        irrep_importance(ip,:) = irrep_importance(ip,:) / (max(irrep_importance(ip,:))+1e-20);

        fprintf('    Combined: R2=%.4f\n', r2_full(ip));
    end

    %% 7. THE TEST: Does T1 (l=1) matter more for dipole components than for scalars?
    fprintf('\n========================================\n');
    fprintf('  WIGNER-ECKART TEST\n');
    fprintf('========================================\n');

    idx_A1 = find(strcmp(irrep_names, 'A1'));
    idx_T1 = find(strcmp(irrep_names, 'T1'));
    idx_E  = find(strcmp(irrep_names, 'E'));

    scalar_idx = find([P.rank]==0);
    vector_idx = find([P.rank]==1);  % mu_x, mu_y, mu_z
    tensor_idx = find([P.rank]==2);

    % T1 (l=1) importance
    T1_scalars = r2_per_irrep(scalar_idx, idx_T1);
    T1_vectors = r2_per_irrep(vector_idx, idx_T1);
    T1_tensor  = r2_per_irrep(tensor_idx, idx_T1);

    % A1 (l=0) importance
    A1_scalars = r2_per_irrep(scalar_idx, idx_A1);
    A1_vectors = r2_per_irrep(vector_idx, idx_A1);

    % T1 relative to A1. A1 R^2 can be ~0 or negative (for dipole
    % components that is the hypothesis); dividing by it raw would blow the
    % ratio up or flip its sign, so the denominator is floored.
    ratio_floor = 1e-2;
    T1_ratio_scalars = T1_scalars ./ max(A1_scalars, ratio_floor);
    T1_ratio_vectors = T1_vectors ./ max(A1_vectors, ratio_floor);

    fprintf('\n  A1 (l=0, invariant) predictive power:\n');
    fprintf('    Scalar properties: R2 = %.4f +/- %.4f\n', mean(A1_scalars), std(A1_scalars));
    fprintf('    Dipole components: R2 = %.4f +/- %.4f\n', mean(A1_vectors), std(A1_vectors));

    fprintf('\n  T1 (l=1, vectors) predictive power:\n');
    fprintf('    Scalar properties: R2 = %.4f +/- %.4f\n', mean(T1_scalars), std(T1_scalars));
    fprintf('    Dipole components: R2 = %.4f +/- %.4f\n', mean(T1_vectors), std(T1_vectors));
    fprintf('    Polarizability:    R2 = %.4f\n', T1_tensor);

    fprintf('\n  T1/A1 ratio (how important is l=1 relative to l=0?):\n');
    fprintf('    Scalar properties: %.4f +/- %.4f\n', mean(T1_ratio_scalars), std(T1_ratio_scalars));
    fprintf('    Dipole components: %.4f +/- %.4f\n', mean(T1_ratio_vectors), std(T1_ratio_vectors));

    E_scalars = r2_per_irrep(scalar_idx, idx_E);
    E_tensor  = r2_per_irrep(tensor_idx, idx_E);

    fprintf('\n  E (l=2) predictive power:\n');
    fprintf('    Scalar properties: R2 = %.4f +/- %.4f\n', mean(E_scalars), std(E_scalars));
    fprintf('    Polarizability:    R2 = %.4f\n', E_tensor);

    fprintf('\n  === WIGNER-ECKART VERDICT ===\n');
    % Test 1: T1 more useful for dipole components than for scalar properties?
    t1_sep = mean(T1_vectors) - mean(T1_scalars);
    fprintf('  T1: dipole components (%.4f) vs scalars (%.4f), separation = %+.4f\n', ...
        mean(T1_vectors), mean(T1_scalars), t1_sep);
    if mean(T1_vectors) > mean(T1_scalars) + std(T1_scalars)
        fprintf('  >>> YES: T1 (l=1) is significantly more predictive for vector properties.\n');
        fprintf('  >>> This is consistent with the Wigner-Eckart theorem.\n');
    elseif t1_sep > 0
        fprintf('  >> Trend in expected direction but not statistically significant.\n');
    else
        fprintf('  > No clear separation.\n');
    end

    % Test 2: A1 less useful for dipole components (they need directional info)?
    a1_sep = mean(A1_scalars) - mean(A1_vectors);
    fprintf('\n  A1: scalars (%.4f) vs dipole components (%.4f), separation = %+.4f\n', ...
        mean(A1_scalars), mean(A1_vectors), a1_sep);
    if mean(A1_vectors) < mean(A1_scalars) - std(A1_scalars)
        fprintf('  >>> YES: Dipole components cannot be predicted from invariants alone.\n');
        fprintf('  >>> They require equivariant (l=1) information.\n');
    end

    % Test 3: T1 ratio
    ratio_sep = mean(T1_ratio_vectors) - mean(T1_ratio_scalars);
    fprintf('\n  T1/A1 ratio: dipole (%.4f) vs scalars (%.4f), separation = %+.4f\n', ...
        mean(T1_ratio_vectors), mean(T1_ratio_scalars), ratio_sep);
    if ratio_sep > 0.1
        fprintf('  >>> STRONG: Dipole components rely much more on l=1 relative to l=0.\n');
    end

    %% 8. Figures
    fprintf('\n  Generating figures...\n');

    % Figure A: Per-irrep R2 heatmap, rows grouped by tensor rank
    fig1 = figure('Position',[50,50,900,500],'Color','w');
    [~, ro] = sort([P.rank]);
    r2_sorted = r2_per_irrep(ro, :);
    names_sorted = {P(ro).name};
    ranks_sorted = [P(ro).rank];

    imagesc(r2_sorted);
    cmap = [ones(1,32)', ones(1,32)', ones(1,32)'; ...
            linspace(1,0.12,32)', linspace(1,0.55,32)', linspace(1,0.30,32)'];
    % The upper limit is kept strictly above 0 so caxis limits stay
    % increasing even when every R^2 is negative.
    colormap(gca, cmap); caxis([0, max(max(r2_sorted(:)), 0)+0.01]);
    cb = colorbar; cb.Label.String = 'R^2 (per-irrep alone)'; cb.Label.FontSize = 11;
    set(gca, 'XTick', 1:n_irreps, 'FontSize', 11);

    % X labels with irrep info
    for ri = 1:n_irreps
        text(ri, np+0.8, sprintf('%s\n(d=%d, l=%d)', irrep_names{ri}, irrep_dims(ri), irrep_info(ri).l), ...
            'HorizontalAlignment', 'center', 'FontSize', 9);
    end
    set(gca, 'XTickLabel', []);

    % Y labels with rank
    set(gca, 'YTick', 1:np, 'YTickLabel', []);
    for i = 1:np
        text(-0.3, i, sprintf('%s (r=%d)', names_sorted{i}, ranks_sorted(i)), ...
            'HorizontalAlignment', 'right', 'FontSize', 10, 'Color', P(ro(i)).color);
    end

    % Rank separators
    hold on;
    for i = 1:np-1
        if ranks_sorted(i) ~= ranks_sorted(i+1)
            plot([0.5, n_irreps+0.5], [i+0.5, i+0.5], 'k-', 'LineWidth', 2);
        end
    end

    % Value labels
    for i = 1:np, for j = 1:n_irreps
        text(j, i, sprintf('%.3f', r2_sorted(i,j)), 'HorizontalAlignment', 'center', ...
            'FontSize', 8, 'Color', 'k');
    end; end

    title('Per-Irrep Predictive Power (Octahedral Group)', 'FontSize', 14, ...
        'FontWeight', 'bold', 'Color', 'k');
    set(fig1, 'InvertHardcopy', 'off');
    exportgraphics(fig1, fullfile(opts.results_dir, 'irrep_heatmap.pdf'), ...
        'ContentType', 'vector', 'BackgroundColor', 'w');
    saveas(fig1, fullfile(opts.results_dir, 'irrep_heatmap.png'));

    % Figure B: Grouped bar chart
    fig2 = figure('Position',[50,50,1000,450],'Color','w');
    b = bar(r2_per_irrep(ro,:), 'grouped', 'EdgeColor', 'none', 'BarWidth', 0.9);
    irrep_colors = [0.7,0.7,0.7; 0.5,0.5,0.5; 0.55,0.20,0.60; 0.85,0.30,0.15; 0.17,0.50,0.72];
    for ri = 1:n_irreps, b(ri).FaceColor = irrep_colors(ri,:); end
    set(gca, 'XTickLabel', names_sorted, 'FontSize', 10); xtickangle(25);
    ylabel('Test R^2 (irrep features alone)', 'FontSize', 12, 'FontWeight', 'bold');
    leg_labels = cell(n_irreps, 1);
    for ri = 1:n_irreps
        leg_labels{ri} = sprintf('%s (d=%d, l=%d)', irrep_names{ri}, irrep_dims(ri), irrep_info(ri).l);
    end
    legend(leg_labels, 'Location', 'northeast', 'FontSize', 9);
    title('Irrep Decomposition of Predictive Power', 'FontSize', 14, ...
        'FontWeight', 'bold', 'Color', 'k');
    grid on; box on;
    set(fig2, 'InvertHardcopy', 'off');
    exportgraphics(fig2, fullfile(opts.results_dir, 'irrep_bars.pdf'), ...
        'ContentType', 'vector', 'BackgroundColor', 'w');
    saveas(fig2, fullfile(opts.results_dir, 'irrep_bars.png'));

    %% Save
    results.P = P;
    results.irrep_info = irrep_info;
    results.r2_per_irrep = r2_per_irrep;
    results.r2_full = r2_full;
    results.irrep_importance = irrep_importance;
    results.irrep_power = irrep_power;
    save(fullfile(opts.results_dir, 'wigner_eckart_results.mat'), 'results');

    %% Summary
    fprintf('\n================================================================\n');
    fprintf('  SUMMARY\n');
    fprintf('================================================================\n');
    fprintf('  %-18s %5s', 'Property', 'Rank');
    for ri = 1:n_irreps, fprintf('  %6s', irrep_names{ri}); end
    fprintf('  %6s\n', 'Full');
    fprintf('  %s\n', repmat('-', 1, 70));
    for ip = 1:np
        fprintf('  %-18s %5d', P(ip).name, P(ip).rank);
        for ri = 1:n_irreps, fprintf('  %6.3f', r2_per_irrep(ip,ri)); end
        fprintf('  %6.3f\n', r2_full(ip));
    end
    fprintf('\n  Figures saved to %s/\n', opts.results_dir);
    fprintf('================================================================\n');
end


function r2 = ridge_test_r2(feat, tri, vai, tei, ytr_n, yva_n, yte_n)
    % RIDGE_TEST_R2  Test-set R^2 of a ridge regression on the rows of FEAT.
    %
    % Features are z-scored with training-row statistics (near-constant
    % columns keep unit scale), an unpenalised intercept column is
    % prepended, lambda is picked from a fixed grid by validation MSE, and
    % the model is refit on the training rows before scoring the test rows.
    %
    % Inputs:
    %   feat                - n_mol x p feature matrix
    %   tri, vai, tei       - row indices of the train/validation/test split
    %   ytr_n, yva_n, yte_n - standardised targets for those rows
    lams = [1e-3, 0.01, 0.1, 1, 10, 100];

    mu_f = mean(feat(tri,:), 1);
    sig_f = std(feat(tri,:), 0, 1);
    sig_f(sig_f < 1e-10) = 1;
    Ftr = [ones(numel(tri),1), (feat(tri,:) - mu_f) ./ sig_f];
    Fva = [ones(numel(vai),1), (feat(vai,:) - mu_f) ./ sig_f];
    Fte = [ones(numel(tei),1), (feat(tei,:) - mu_f) ./ sig_f];

    pen = eye(size(Ftr,2)); pen(1,1) = 0;   % do not shrink the intercept
    G = Ftr' * Ftr;
    rhs = Ftr' * ytr_n;

    best_err = Inf; best_lam = 1;
    for lam = lams
        wt = (G + lam*pen) \ rhs;
        err = mean((yva_n - Fva*wt).^2);
        if err < best_err, best_err = err; best_lam = lam; end
    end

    w = (G + best_lam*pen) \ rhs;
    yp = Fte * w;
    r2 = 1 - sum((yte_n - yp).^2) / (sum((yte_n - mean(yte_n)).^2) + 1e-20);
end


%% ========================================================================
%% Build the chiral octahedral group (24 rotation matrices)
%% ========================================================================
function [R_mats, mult_table, irrep_info] = build_octahedral_group()
    % BUILD_OCTAHEDRAL_GROUP  The chiral octahedral group O as explicit 3x3
    % rotations, together with its Cayley table and all five irreps.
    %
    % Returns:
    %   R_mats     - 3 x 3 x ng rotation matrices (ng = 24), entries rounded
    %                to a 1e-10 grid to remove floating-point noise;
    %                R_mats(:,:,1) is the identity.
    %   mult_table - ng x ng, mult_table(i,j) = k  <=>  R_i * R_j == R_k.
    %   irrep_info - 1 x 5 struct array for A1, A2, E, T1, T2 with fields
    %                name, dim, l (lowest angular-momentum channel in which
    %                the irrep appears), desc, matrices (ng x 1 cell of
    %                dim x dim matrices).
    %
    % Errors if a product of two generated rotations is missing from the
    % set, or if any constructed irrep fails the homomorphism check.

    % Rodrigues' formula: rotation by theta about the unit vector axis.
    Rot = @(axis, theta) cos(theta)*eye(3) + sin(theta)*cross_mat(axis) + ...
        (1-cos(theta))*(axis(:)*axis(:)');

    R_mats = zeros(3, 3, 24);
    idx = 1;
    R_mats(:,:,idx) = eye(3);                    % identity

    % 9 face rotations: 90, 180, 270 degrees about x, y, z
    axes_face = {[1;0;0], [0;1;0], [0;0;1]};
    for a = 1:3
        for theta = [pi/2, pi, 3*pi/2]
            idx = idx+1;
            R_mats(:,:,idx) = Rot(axes_face{a}, theta);
        end
    end

    % 8 vertex rotations: 120, 240 degrees about the 4 body diagonals
    diags = {[1;1;1]/sqrt(3), [1;1;-1]/sqrt(3), [1;-1;1]/sqrt(3), [-1;1;1]/sqrt(3)};
    for k = 1:4
        for theta = [2*pi/3, 4*pi/3]
            idx = idx+1;
            R_mats(:,:,idx) = Rot(diags{k}, theta);
        end
    end

    % 6 edge rotations: 180 degrees about the edge-midpoint axes
    edges = {[1;1;0]/sqrt(2), [1;-1;0]/sqrt(2), [1;0;1]/sqrt(2), ...
             [1;0;-1]/sqrt(2), [0;1;1]/sqrt(2), [0;1;-1]/sqrt(2)};
    for e = 1:6
        idx = idx+1;
        R_mats(:,:,idx) = Rot(edges{e}, pi);
    end

    fprintf('  Generated %d rotation matrices\n', idx);

    % Round away numerical noise and drop any duplicates.
    unique_R = zeros(3,3,0);
    for i = 1:idx
        R = round(R_mats(:,:,i) * 1e10) / 1e10;
        is_dup = false;
        for j = 1:size(unique_R,3)
            if max(abs(R - unique_R(:,:,j)),[],'all') < 1e-8
                is_dup = true; break;
            end
        end
        if ~is_dup
            unique_R(:,:,end+1) = R;
        end
    end
    R_mats = unique_R;
    ng = size(R_mats, 3);
    fprintf('  Unique rotations: %d\n', ng);

    % Cayley table: index of R_i * R_j within R_mats.
    mult_table = zeros(ng, ng);
    for i = 1:ng
        for j = 1:ng
            Rij = R_mats(:,:,i) * R_mats(:,:,j);
            for k = 1:ng
                if max(abs(Rij - R_mats(:,:,k)),[],'all') < 1e-6
                    mult_table(i,j) = k;
                    break;
                end
            end
            if mult_table(i,j) == 0
                error('Product R(%d)*R(%d) not found in group!', i, j);
            end
        end
    end

    assert(all(mult_table(1,:) == 1:ng), 'First element should be identity');

    % A1: trivial (1D)
    A1 = num2cell(ones(ng, 1));

    % A2: sign representation (1D). Every element of O is a signed
    % permutation matrix; A2(g) is the parity of the underlying permutation
    % of the x, y, z axes (+1 on identity, C3 and face C2; -1 on C4 and
    % edge C2'). det(R) would be +1 for every proper rotation and merely
    % duplicate A1.
    A2 = cell(ng,1);
    for g = 1:ng
        A2{g} = round(det(abs(R_mats(:,:,g))));
    end

    % T1: the defining 3D representation -- the rotation matrices themselves.
    T1 = cell(ng,1);
    for g = 1:ng, T1{g} = R_mats(:,:,g); end

    % l=2 (symmetric traceless rank-2 tensors), 5D, in build_l2_rep's basis
    % {xy, xz, yz, x^2-y^2, 2z^2-x^2-y^2}. Signed permutations map diagonal
    % tensors to diagonal ones and off-diagonal to off-diagonal, so the
    % matrix is block diagonal: E on basis 4:5, T2 on basis 1:3.
    E_mats = cell(ng, 1);
    T2_mats = cell(ng, 1);
    for g = 1:ng
        D = build_l2_rep(R_mats(:,:,g));
        E_mats{g} = D(4:5, 4:5);
        T2_mats{g} = D(1:3, 1:3);
    end

    % Homomorphism check rho(g_i) * rho(g_j) == rho(g_i * g_j).
    ok_A2 = true; ok_E = true; ok_T2 = true;
    for i = 1:ng
        for j = 1:ng
            k = mult_table(i,j);
            if A2{i}*A2{j} ~= A2{k}, ok_A2 = false; end
            if norm(E_mats{i}*E_mats{j} - E_mats{k}) > 1e-6, ok_E = false; end
            if norm(T2_mats{i}*T2_mats{j} - T2_mats{k}) > 1e-6, ok_T2 = false; end
        end
    end
    fprintf('  E rep valid: %d, T2 rep valid: %d\n', ok_E, ok_T2);
    if ~(ok_A2 && ok_E && ok_T2)
        % Continuing with non-representations would silently corrupt every
        % downstream per-irrep power, so fail loudly instead.
        error('build_octahedral_group:invalidIrrep', ...
            'Irrep construction failed (A2 valid: %d, E valid: %d, T2 valid: %d)', ...
            ok_A2, ok_E, ok_T2);
    end

    % Package irrep info (l = lowest angular-momentum channel containing it)
    irrep_info = struct('name',{},'dim',{},'l',{},'desc',{},'matrices',{});
    irrep_info(1) = struct('name','A1', 'dim',1, 'l',0, 'desc','trivial (scalar)', 'matrices',{A1});
    irrep_info(2) = struct('name','A2', 'dim',1, 'l',3, 'desc','sign rep (xyz-like, l=3)', 'matrices',{A2});
    irrep_info(3) = struct('name','E',  'dim',2, 'l',2, 'desc','quadrupole (l=2)', 'matrices',{E_mats});
    irrep_info(4) = struct('name','T1', 'dim',3, 'l',1, 'desc','vector (l=1)', 'matrices',{T1});
    irrep_info(5) = struct('name','T2', 'dim',3, 'l',2, 'desc','quadrupole (l=2)', 'matrices',{T2_mats});
end


function S = cross_mat(v)
    % CROSS_MAT  Skew-symmetric matrix [v]_x of a 3-vector, so that
    % cross_mat(v) * u equals cross(v, u) for any 3-vector u.
    x = v(1); y = v(2); z = v(3);
    S = [ 0, -z,  y;
          z,  0, -x;
         -y,  x,  0];
end


function D = build_l2_rep(R)
    % BUILD_L2_REP  5x5 matrix of the l=2 representation (symmetric traceless
    % rank-2 tensors) for the 3x3 rotation R.
    %
    % A tensor T transforms as T -> R*T*R'. Entry D(row,col) is the
    % Frobenius inner product of basis tensor ROW with the rotated basis
    % tensor COL. The basis is orthonormal under that inner product:
    %   1: xy               [0 1 0; 1 0 0; 0 0 0] / sqrt(2)
    %   2: xz               [0 0 1; 0 0 0; 1 0 0] / sqrt(2)
    %   3: yz               [0 0 0; 0 0 1; 0 1 0] / sqrt(2)
    %   4: x^2 - y^2        diag(1, -1, 0)       / sqrt(2)
    %   5: 2z^2 - x^2 - y^2 diag(-1, -1, 2)      / sqrt(6)
    s2 = sqrt(2);
    s6 = sqrt(6);
    B = { [0,1,0;1,0,0;0,0,0]/s2, ...
          [0,0,1;0,0,0;1,0,0]/s2, ...
          [0,0,0;0,0,1;0,1,0]/s2, ...
          [1,0,0;0,-1,0;0,0,0]/s2, ...
          [-1,0,0;0,-1,0;0,0,2]/s6 };

    nb = numel(B);
    D = zeros(nb);
    for col = 1:nb
        rotated = R * B{col} * R';
        for row = 1:nb
            D(row, col) = sum(sum(B{row} .* rotated));
        end
    end
end


%% ========================================================================
%% Molecular features under 3D rotations
%% ========================================================================
function F = molecular_features_3d(coords, Z, R_mats, n_feat_target)
    % MOLECULAR_FEATURES_3D  Orientation-resolved feature matrix of one molecule.
    %
    % Column g of F holds features computed after rotating the molecule by
    % R_mats(:,:,g), so each row of F is a signal over the rotation set.
    %
    % Inputs:
    %   coords        - na x 3 atom positions (the caller in this file passes
    %                   mean-centred coordinates)
    %   Z             - na atomic numbers (any vector orientation)
    %   R_mats        - 3 x 3 x ng rotation matrices
    %   n_feat_target - number of rows of F; extra rows are truncated and
    %                   missing rows zero-padded
    % Output:
    %   F             - n_feat_target x ng
    %
    % Row layout (before padding/truncation), identical for every molecule:
    %   1-9   Z-normalised moments: x, y, z, x^2, y^2, z^2, xy, xz, yz
    %   10-14 raw-Z moments: x, y, z, x^2, y^2
    %   15-16 Z-normalised third moments: x^3, y^3
    %   17-25 x, y, z of the 3 heaviest atoms (zeros for absent atoms)
    %   26-29 mean/std/min/max interatomic distance (rotation-invariant)
    %   30-31 sum(Z.^2)/100 and mean(Z) (rotation-invariant)
    na = size(coords, 1);
    ng = size(R_mats, 3);
    n_atom_slots = 3;                      % heaviest atoms tracked individually
    Zn = Z(:);
    w = Zn / (sum(Zn) + 1e-10);            % Z-normalised weights

    % Projections of every atom onto x, y, z after each rotation g
    px = zeros(na, ng); py = zeros(na, ng); pz = zeros(na, ng);
    for g = 1:ng
        rot_coords = (R_mats(:,:,g) * coords')';
        px(:,g) = rot_coords(:,1);
        py(:,g) = rot_coords(:,2);
        pz(:,g) = rot_coords(:,3);
    end

    feat_list = { ...
        w'*px, w'*py, w'*pz, ...                     weighted first moments
        w'*(px.^2), w'*(py.^2), w'*(pz.^2), ...      weighted second moments
        w'*(px.*py), w'*(px.*pz), w'*(py.*pz), ...   weighted cross moments
        Zn'*px, Zn'*py, Zn'*pz, ...                  Z-weighted first moments
        Zn'*(px.^2), Zn'*(py.^2), ...                Z-weighted second moments
        w'*(px.^3), w'*(py.^3)};

    % Heaviest atoms (stable sort by Z). Absent atoms are zero-padded so that
    % every later row keeps the same meaning across molecules of any size.
    [~, si] = sort(Zn, 'descend');
    for k = 1:n_atom_slots
        if k <= na
            feat_list(end+1:end+3) = {px(si(k),:), py(si(k),:), pz(si(k),:)};
        else
            feat_list(end+1:end+3) = {zeros(1,ng), zeros(1,ng), zeros(1,ng)};
        end
    end

    % Rotation-invariant distance statistics, repeated across all g.
    % Pairwise distances are computed in base MATLAB (no Statistics
    % Toolbox), in the same pair order as pdist: (2,1),(3,1),...,(3,2),...
    if na >= 2
        delta = permute(coords, [1 3 2]) - permute(coords, [3 1 2]);  % na x na x 3
        Dfull = sqrt(sum(delta.^2, 3));
        D = Dfull(tril(true(na), -1));
        inv_vals = [mean(D), std(D), min(D), max(D)];
    else
        inv_vals = zeros(1, 4);
    end
    for v = inv_vals
        feat_list{end+1} = repmat(v, 1, ng); %#ok<AGROW>
    end
    feat_list{end+1} = repmat(sum(Zn.^2)/100, 1, ng);
    feat_list{end+1} = repmat(mean(Zn), 1, ng);

    F = cell2mat(feat_list');
    nr = size(F,1);
    if nr < n_feat_target
        F = [F; zeros(n_feat_target-nr, ng)];
    elseif nr > n_feat_target
        F = F(1:n_feat_target, :);
    end
end


%% ========================================================================
%% Simple QM9 loader
%% ========================================================================
function [coords, Z_all, charges_all, props, n_valid] = load_qm9_simple(data_dir, n_max)
    % LOAD_QM9_SIMPLE  Parse up to N_MAX .xyz files from DATA_DIR.
    %
    % Expected per-file layout (standard QM9 release -- NOTE(review): confirm
    % against the files in data_dir):
    %   line 1    : atom count na
    %   line 2    : 'gdb <index> <A> <B> <C> <mu> <alpha> <homo> <lumo> <gap> ...'
    %   na lines  : '<element> <x> <y> <z> [<charge>]'
    % Numbers may use Mathematica exponent notation ('1.2*^-5').
    %
    % Outputs (one entry per successfully parsed file, n_valid in total):
    %   coords{i}      - na x 3 positions
    %   Z_all{i}       - na x 1 atomic numbers (unknown symbols default to 6)
    %   charges_all{i} - na x 1 charges (0 where the column is absent)
    %   props(i,:)     - numeric tokens after the leading tag: column 1 is the
    %                    molecule index and the following columns are the
    %                    properties in file order (16 columns, zero-filled
    %                    where a token is missing or non-numeric).
    %
    % Unreadable or malformed files (bad atom count, truncated or short atom
    % lines, non-numeric coordinates) are skipped and counted.
    n_prop_cols = 16;
    xyz_files = dir(fullfile(data_dir, '*.xyz'));
    n_files = min(length(xyz_files), n_max);
    fprintf('Loading %d .xyz files...\n', n_files);
    t0 = tic;
    coords = cell(n_files,1); Z_all = cell(n_files,1); charges_all = cell(n_files,1);
    props = zeros(n_files, n_prop_cols);
    em = containers.Map({'H','C','N','O','F','S','Cl','Br','I','P','Si','B'},{1,6,7,8,9,16,17,35,53,15,14,5});
    n_valid = 0;
    n_skipped = 0;
    for f = 1:n_files
        fp = fullfile(xyz_files(f).folder, xyz_files(f).name);
        try
            [ok, pos, Z, q, pr] = parse_qm9_xyz(fp, em, n_prop_cols);
        catch
            ok = false;   % best effort: any unexpected parse error skips the file
        end
        if ok
            n_valid = n_valid + 1;
            coords{n_valid} = pos; Z_all{n_valid} = Z; charges_all{n_valid} = q;
            props(n_valid, 1:numel(pr)) = pr;
        else
            n_skipped = n_skipped + 1;
        end
        if mod(f,2000)==0, fprintf('  %d/%d (%.1fs)\n',f,n_files,toc(t0)); end
    end
    coords=coords(1:n_valid); Z_all=Z_all(1:n_valid); charges_all=charges_all(1:n_valid);
    props=props(1:n_valid,:);
    fprintf('Loaded %d molecules (%.1fs)\n', n_valid, toc(t0));
    if n_skipped > 0
        fprintf('  Skipped %d unreadable or malformed files\n', n_skipped);
    end
end


function [ok, pos, Z, q, pr] = parse_qm9_xyz(fp, em, n_prop_cols)
    % PARSE_QM9_XYZ  Parse one .xyz file; ok is false if it is unreadable or
    % malformed. The file is always closed on return (onCleanup).
    ok = false; pos = []; Z = []; q = []; pr = [];
    to_num = @(s) str2double(strrep(s, '*^', 'e'));

    fid = fopen(fp, 'r');
    if fid < 0, return; end
    closer = onCleanup(@() fclose(fid)); %#ok<NASGU>

    na = str2double(fgetl(fid));           % fgetl returns -1 at EOF -> NaN
    if isnan(na) || na < 1, return; end

    pl = fgetl(fid);
    if ~ischar(pl), return; end
    tk = strsplit(strtrim(strrep(pl, '*^', 'e')));
    % tk{1} is the tag; keep the index and the properties that follow it.
    n_tok = max(0, min(n_prop_cols, numel(tk) - 1));
    pr = zeros(1, n_tok);
    for t = 1:n_tok
        v = str2double(tk{t + 1});
        if ~isnan(v), pr(t) = v; end
    end

    Z = zeros(na,1); pos = zeros(na,3); q = zeros(na,1);
    for a = 1:na
        ln = fgetl(fid);
        if ~ischar(ln), return; end                 % file ended early
        pts = strsplit(strtrim(strrep(ln, char(9), ' ')));
        if numel(pts) < 4, return; end              % malformed atom line
        if isKey(em, pts{1}), Z(a) = em(pts{1}); else, Z(a) = 6; end
        pos(a,:) = to_num(pts(2:4));
        if numel(pts) >= 5, q(a) = to_num(pts{5}); end
    end
    if any(isnan(pos(:))), return; end
    ok = true;
end