%% ========================================================================
%% product_group_experiment.m
%% Demonstrate the compositional advantage of star_G features on a product
%% group G1 x G2 over single-factor groups and a flat cyclic group Z_(n1*n2)
%%
%% DESIGN:
%%   G1 = Z_n1 (angular measurement in the xy-plane: shifts by one step
%%               under a z-rotation by 2*pi/n1)
%%   G2 = Z_n2 (axial measurement along z: shifts by one step under a
%%               z-translation by L_axial/n2)
%%   G  = G1 x G2 (commuting product group, order n1*n2)
%%
%%   The two group actions COMMUTE because:
%%     R_z changes (x,y) but not z
%%     T_z changes z but not (x,y)
%%
%%   Features at (g1,g2) include:
%%     - angular projections (shift in g1, constant in g2)
%%     - axial periodic functions (constant in g1, shift in g2)
%%     - coupled products (shift in both)
%%
%%   The synthetic target depends on the 2D Fourier power of the coupled
%%   feature rows. The COUPLED (freq1>0, freq2>0) terms dominate, so only
%%   the product group can capture the full signal.
%%
%% LH & Claude 2026
%% ========================================================================

classdef product_group_experiment < handle
    % PRODUCT_GROUP_EXPERIMENT  star_G features on G = Z_n1 x Z_n2 versus
    % single-factor groups, a flat cyclic group Z_(n1*n2) and MLP baselines.
    %
    % Workflow used by the methods below:
    %   obj = product_group_experiment(n1, n2);
    %   obj.load_data(data_dir, n_max);   % *.xyz files, or synthetic fallback
    %   obj.compute_features();           % fills X_tensor
    %   obj.run_comparison(target_col);   % e.g. 8 for the synthetic target
    %   obj.discover_factorization(target_col);
    %
    % External dependencies (defined elsewhere): StarGAlgebra,
    % extractStarGFeatures; pdist (Statistics and Machine Learning Toolbox).

    properties
        n1              % order of G1 = Z_n1 (angular factor)
        n2              % order of G2 = Z_n2 (axial factor)
        n_molecules     % number of molecules currently loaded
        coords          % n_molecules x 1 cell of (na x 3) positions
        atomic_numbers  % n_molecules x 1 cell of (na x 1) atomic numbers
        properties_mat  % n_molecules x 15 numeric targets (synthetic target in column 8)
        G_prod; G1; G2  % StarGAlgebra objects for G1 x G2, Z_n1 and Z_n2
        X_tensor        % n_molecules x n_feat_per_rot x (n1*n2) feature tensor
        n_feat_per_rot  % number of feature rows per group element
        is_real_data    % true when data came from .xyz files
        results         % struct returned by the last run_comparison call
        L_axial         % characteristic length for axial embedding
    end

    methods
        function obj = product_group_experiment(n1, n2, varargin)
            % Create the experiment for G = Z_n1 x Z_n2.
            %   n1, n2        : orders of the angular (G1) and axial (G2) factors
            %   'n_molecules' : name-value option (default 1000); it is parsed
            %                   but not otherwise used -- load_data takes n_max.
            p = inputParser;
            addRequired(p,'n1'); addRequired(p,'n2');
            addParameter(p,'n_molecules',1000);
            parse(p, n1, n2, varargin{:});
            obj.n1 = n1; obj.n2 = n2;
            obj.n_molecules = 0;
            obj.L_axial = 6;  % z-range of one period of the axial embedding
            obj.is_real_data = false;

            obj.G1 = StarGAlgebra('cyclic', n1);
            obj.G2 = StarGAlgebra('cyclic', n2);
            obj.G_prod = StarGAlgebra('product', {obj.G1, obj.G2});
            fprintf('Product Group: G = Z_%d x Z_%d (order %d)\n', n1, n2, obj.G_prod.n);
            fprintf('  G1 = Z_%d (xy-angular, shifts under z-rotation)\n', n1);
            fprintf('  G2 = Z_%d (z-axial, shifts under z-translation)\n', n2);
        end

        %% DATA ============================================================
        function obj = load_data(obj, data_dir, n_max)
            % Load up to n_max (default 1000) molecules from *.xyz files in
            % data_dir. When data_dir is omitted/empty, is not a folder, or holds
            % no .xyz files, generate min(n_max, 1000) synthetic molecules.
            if nargin < 3, n_max = 1000; end
            if nargin >= 2 && ~isempty(data_dir) && isfolder(data_dir)
                xyz_files = dir(fullfile(data_dir, '*.xyz'));
                if ~isempty(xyz_files)
                    obj = obj.load_xyz_files(xyz_files, n_max);
                    obj.is_real_data = true;
                    return;
                end
            end
            obj = obj.generate_synthetic(min(n_max, 1000));
        end

        function obj = load_xyz_files(obj, xyz_files, n_max)
            % Parse the first min(numel(xyz_files), n_max) entries of a dir()
            % listing. Files that cannot be read, are truncated, or contain
            % non-numeric coordinates are skipped and counted (see parse_xyz).
            n_files = min(numel(xyz_files), n_max);
            fprintf('Loading %d .xyz files...\n', n_files); t0 = tic;
            obj.coords = cell(n_files,1); obj.atomic_numbers = cell(n_files,1);
            obj.properties_mat = zeros(n_files,15);
            % Element symbol -> atomic number; unknown symbols fall back to 6.
            em = containers.Map({'H','C','N','O','F','S','Cl','Br','I','P','Si','B'}, ...
                                {1,6,7,8,9,16,17,35,53,15,14,5});
            valid = 0; skipped = 0;
            for f = 1:n_files
                fp = fullfile(xyz_files(f).folder, xyz_files(f).name);
                [ok, pos, Z, pr] = obj.parse_xyz(fp, em);
                if ok
                    valid = valid + 1;
                    obj.coords{valid} = pos; obj.atomic_numbers{valid} = Z;
                    obj.properties_mat(valid,1:numel(pr)) = pr;
                else
                    skipped = skipped + 1;
                end
                if mod(f,2000)==0, fprintf('  %d/%d (%.1fs)\n',f,n_files,toc(t0)); end
            end
            obj.coords = obj.coords(1:valid); obj.atomic_numbers = obj.atomic_numbers(1:valid);
            obj.properties_mat = obj.properties_mat(1:valid,:); obj.n_molecules = valid;
            fprintf('Loaded %d molecules (%.1fs)\n', valid, toc(t0));
            if skipped > 0
                fprintf('  Skipped %d unreadable or incomplete files\n', skipped);
            end
        end

        function obj = generate_synthetic(obj, n_mol)
            % Generate n_mol random molecules plus a synthetic target stored in
            % properties_mat(:,8).
            % Each molecule has 4-12 atoms drawn from {H,C,N,O,F}, with
            % coordinates randn*1.5 shifted to zero mean. The target is linear in
            % the 2D Fourier power (r*G_prod.F) of build_features rows 7 and 8,
            % the first two coupled rows. With (f1,f2) 0-indexed frequencies:
            %   y = 5|r7(1,1)|^2 + 3|r7(1,2)|^2 + 3|r7(2,1)|^2 + 2|r8(1,1)|^2
            %     +  |r7(1,0)|^2 +  |r7(0,1)|^2
            % Resets the global RNG (rng(42,'twister')) so the data are reproducible.
            % Requires n1 >= 3 and n2 >= 3, because frequency 2 is used on both axes.
            n1 = obj.n1; n2 = obj.n2;
            if n1 < 3 || n2 < 3
                error('product_group_experiment:factorTooSmall', ...
                    ['generate_synthetic needs n1 >= 3 and n2 >= 3 (got n1=%d, n2=%d): ', ...
                     'the target uses frequency 2 on both axes.'], n1, n2);
            end
            fprintf('Generating %d synthetic molecules...\n', n_mol);
            rng(42,'twister'); atom_types = [1,6,7,8,9];
            obj.coords = cell(n_mol,1); obj.atomic_numbers = cell(n_mol,1);
            obj.properties_mat = zeros(n_mol,15);
            obj.is_real_data = false;
            [ang1, ang2] = obj.group_angles();

            % Product group Fourier matrix
            F_prod = obj.G_prod.F;

            % 0-indexed frequency pair (f1,f2) -> 1-indexed column of F_prod.
            % NOTE(review): assumes G_prod.F orders frequencies with f1 fastest,
            % matching the feature layout (g2-1)*n1+g1 -- confirm in StarGAlgebra.
            fidx = @(f1,f2) f2*n1 + f1 + 1;
            i11 = fidx(1,1); i12 = fidx(1,2); i21 = fidx(2,1);   % coupled
            i10 = fidx(1,0);                                     % axis-1 only
            i01 = fidx(0,1);                                     % axis-2 only

            for i = 1:n_mol
                na = randi([4,12]); pos = randn(na,3)*1.5; pos = pos - mean(pos,1);
                Z = atom_types(randi(5,na,1))';
                obj.coords{i} = pos; obj.atomic_numbers{i} = Z;

                F_all = obj.build_features(pos, Z, ang1, ang2, 36);
                pw7 = abs(F_all(7,:) * F_prod).^2;   % row 7: w'*(p1.*az)
                pw8 = abs(F_all(8,:) * F_prod).^2;   % row 8: Z'*(p1.*az)

                % Coupled terms dominate (coeff 13) vs single-axis (coeff 2)
                y_c = 5.0*pw7(i11) + 3.0*pw7(i12) + 3.0*pw7(i21) + 2.0*pw8(i11);
                y_1 = 1.0*pw7(i10);
                y_2 = 1.0*pw7(i01);
                obj.properties_mat(i,8) = y_c + y_1 + y_2;
            end
            obj.n_molecules = n_mol;
            y = obj.properties_mat(1:n_mol,8);
            fprintf('Target: linear in per-row 2D Fourier power of coupled rows\n');
            fprintf('  Coupled coeff=13 vs single-axis=2 (6.5x ratio)\n');
            if n_mol > 0
                fprintf('  y range: [%.2f, %.2f], std=%.2f\n', min(y), max(y), std(y));
            end
        end

        %% FEATURE COMPUTATION ============================================
        function obj = compute_features(obj, varargin)
            % Fill X_tensor (n_molecules x nf x n1*n2) with build_features of
            % every molecule, after centring each molecule at its coordinate
            % mean. Name-value 'n_feat' (default 36) sets nf. Afterwards it
            % runs verify_product_structure.
            p = inputParser; addParameter(p,'n_feat',36); parse(p,varargin{:});
            nft = p.Results.n_feat; n_mol = obj.n_molecules;
            if n_mol < 1
                error('product_group_experiment:noData', ...
                    'No molecules loaded; call load_data before compute_features.');
            end
            ng = obj.n1*obj.n2;
            [ang1, ang2] = obj.group_angles();
            centre = @(c) c - mean(c,1);

            feat0 = obj.build_features(centre(obj.coords{1}), obj.atomic_numbers{1}, ang1, ang2, nft);
            nf = size(feat0,1); obj.n_feat_per_rot = nf;

            X = zeros(n_mol,nf,ng); t0 = tic;
            fprintf('Computing product features: %d mol x %d feat x %d group...\n',n_mol,nf,ng);
            X(1,:,:) = reshape(feat0, 1, nf, ng);
            for mol = 2:n_mol
                F = obj.build_features(centre(obj.coords{mol}), obj.atomic_numbers{mol}, ang1, ang2, nft);
                X(mol,:,:) = reshape(F, 1, nf, ng);
                if mod(mol,500)==0, fprintf('  %d/%d (%.1fs)\n',mol,n_mol,toc(t0)); end
            end
            obj.X_tensor = X;
            % Explicit sizes: size(X) would drop a trailing singleton when ng == 1.
            fprintf('Feature tensor: %d x %d x %d (%.1fs)\n', n_mol, nf, ng, toc(t0));
            obj.verify_product_structure();
        end

        function F = build_features(obj, coords, Z, ang1, ang2, nft)
            % Build the nft x (n1*n2) feature matrix of one molecule.
            % Column (g2-1)*n1 + g1 holds the features at group element (g1,g2).
            % Here g1 rotates the xy projection axis to angle ang1(g1) and g2
            % shifts the axial phase by ang2(g2). Rows before padding:
            %    1-3   angular, vary with g1 only: w'*p1, Z'*p1, w'*p1.^2
            %    4-6   axial, vary with g2 only:   w'*c, Z'*c, w'*s
            %    7-10  coupled, vary with both:    w'*(p1.*c), Z'*(p1.*c),
            %                                      w'*(p1.^2.*c), w'*(p1.*c.^2)
            %   11-17  invariant: mean/std/min/max pairwise distance (zeros
            %          when fewer than 2 atoms), sum(Z.^2)/100, mean(Z), atom count
            %   18-20  higher-order angular, g1 only: w'*p1.^3, w'*(p1.*p2), Z'*p1.^2
            % p1/p2 are the projections onto the rotated x/y axes, and
            % c/s = cos/sin(2*pi*z/L_axial - ang2(g2)). w = Z/sum(Z).
            % The result is zero-padded or truncated to nft rows.
            %   coords     : na x 3 positions
            %   Z          : na atomic numbers (row or column)
            %   ang1, ang2 : group angles of the two factors
            n1 = numel(ang1); n2 = numel(ang2); ng = n1*n2;
            na = size(coords,1);
            Zn = Z(:); w = Zn/(sum(Zn)+1e-10);
            a1 = reshape(ang1,1,[]); a2 = reshape(ang2,1,[]);
            x = coords(:,1); y = coords(:,2);
            phase = 2*pi*coords(:,3)/obj.L_axial;

            % Per-atom signals: one column per element of the relevant factor
            P1 = x*cos(a1) + y*sin(a1);     % na x n1, projection on rotated x-axis
            P2 = -x*sin(a1) + y*cos(a1);    % na x n1, projection on rotated y-axis
            C2 = cos(phase - a2);           % na x n2, axial cosine embedding
            S2 = sin(phase - a2);           % na x n2, axial sine embedding

            ang_vals = [(w'*P1).', (Zn'*P1).', (w'*(P1.^2)).'];              % n1 x 3
            ax_vals  = [(w'*C2).', (Zn'*C2).', (w'*S2).'];                   % n2 x 3
            ho_vals  = [(w'*(P1.^3)).', (w'*(P1.*P2)).', (Zn'*(P1.^2)).'];   % n1 x 3

            % Coupled: entry (g1,g2) = sum over atoms of weight * P1(:,g1)-term * C2(:,g2)-term
            C_a = (w.*P1).'*C2;          % n1 x n2
            C_b = (Zn.*P1).'*C2;
            C_c = (w.*P1.^2).'*C2;
            C_d = (w.*P1).'*(C2.^2);
            % reshape is column-major, so g1 runs fastest, as in (g2-1)*n1+g1
            cpl_rows = [reshape(C_a,1,[]); reshape(C_b,1,[]); reshape(C_c,1,[]); reshape(C_d,1,[])];

            % Replicate factor-only signals over the other factor's index
            along_g1 = @(M) repmat(M.', 1, n2);    % n1 x k -> k x ng, constant in g2
            along_g2 = @(M) repelem(M.', 1, n1);   % n2 x k -> k x ng, constant in g1

            if na >= 2
                D = pdist(coords);
                dist_stats = [mean(D); std(D); min(D); max(D)];
            else
                dist_stats = zeros(4,1);
            end
            inv_rows = repmat([dist_stats; sum(Zn.^2)/100; mean(Zn); na], 1, ng);

            F = [along_g1(ang_vals); along_g2(ax_vals); cpl_rows; inv_rows; along_g1(ho_vals)];
            nr = size(F,1);
            if nr < nft
                F = [F; zeros(nft-nr, ng)];
            elseif nr > nft
                F = F(1:nft,:);
            end
        end

        function [err1, err2] = verify_product_structure(obj)
            % Check numerically that build_features is equivariant:
            %   rotating about z by 2*pi/n1 -> circular shift by +1 along g1;
            %   translating z by L_axial/n2 -> circular shift by +1 along g2.
            % Uses molecule min(3, n_molecules). Prints PASS when the relative
            % error is below 1e-10 and WARN otherwise. Returns both relative
            % errors, or NaN when no molecules are loaded.
            err1 = NaN; err2 = NaN;
            if obj.n_molecules < 1, return; end
            mi = min(3, obj.n_molecules);
            pos = obj.coords{mi} - mean(obj.coords{mi},1);
            Z = obj.atomic_numbers{mi};
            [ang1, ang2] = obj.group_angles();
            nft = obj.n_feat_per_rot;
            feats3d = @(c) reshape(obj.build_features(c, Z, ang1, ang2, nft), [], obj.n1, obj.n2);

            F0 = feats3d(pos);
            scale = norm(F0(:)) + 1e-20;

            % G1: rotate about z by 2*pi/n1
            th1 = 2*pi/obj.n1;
            R1 = [cos(th1),-sin(th1),0; sin(th1),cos(th1),0; 0,0,1];
            Fr1 = feats3d((R1*pos')');
            err1 = norm(Fr1(:) - reshape(circshift(F0,1,2),[],1)) / scale;

            % G2: translate z by L_axial/n2
            pos_tz = pos; pos_tz(:,3) = pos_tz(:,3) + obj.L_axial/obj.n2;
            Ftz = feats3d(pos_tz);
            err2 = norm(Ftz(:) - reshape(circshift(F0,1,3),[],1)) / scale;

            fprintf('Product structure check:\n');
            fprintf('  G1 shift (z-rotation):     %.2e %s\n', err1, tif(err1<1e-10,'PASS','WARN'));
            fprintf('  G2 shift (z-translation):  %.2e %s\n', err2, tif(err2<1e-10,'PASS','WARN'));
        end

        %% FOLD DATA ======================================================
        function X_sub = fold_to_factor(obj, X, factor)
            % Average a feature tensor over one factor.
            %   X      : ns x nf x (n1*n2), group index (g2-1)*n1+g1
            %   factor : 1 -> keep G1, average over g2, returns ns x nf x n1
            %            2 -> keep G2, average over g1, returns ns x nf x n2
            % reshape, not squeeze, so that ns == 1 or nf == 1 keeps its axis.
            [ns,nf,~] = size(X);
            X_4d = reshape(X, ns, nf, obj.n1, obj.n2);
            switch factor
                case 1
                    X_sub = reshape(mean(X_4d,4), ns, nf, obj.n1);
                case 2
                    X_sub = reshape(mean(X_4d,3), ns, nf, obj.n2);
                otherwise
                    error('product_group_experiment:badFactor', ...
                        'factor must be 1 or 2 (got %s).', mat2str(factor));
            end
        end

        %% COMPARISON =====================================================
        function results = run_comparison(obj, target_col, varargin)
            % Train and evaluate eight models on the same random splits.
            %   target_col : column of properties_mat used as the regression target
            %   'n_seeds'  : number of random splits (default 3)
            %   'split'    : train/val fractions (default [0.7 0.15 0.15]);
            %                the test set is every remaining molecule
            % Returns a struct, also stored in obj.results, with fields
            % method_names, R2 and RMSE (methods x seeds), and params (methods x 1).
            p = inputParser; addParameter(p,'n_seeds',3); addParameter(p,'split',[0.7,0.15,0.15]);
            parse(p,varargin{:});
            n_seeds = p.Results.n_seeds; sp = p.Results.split;
            if isempty(obj.X_tensor)
                error('product_group_experiment:noFeatures', ...
                    'Call compute_features before run_comparison.');
            end
            y = obj.properties_mat(:,target_col); n = obj.n_molecules;
            ng = obj.n1*obj.n2; nfa = obj.n_feat_per_rot;

            mn = {'G1xG2_Ridge','G1xG2_MLP','G1_only_Ridge','G2_only_Ridge', ...
                  'Z_n_Ridge','Standard_MLP','Invariant_MLP','Augmented_MLP'};
            nm = numel(mn);
            R2_all = zeros(nm,n_seeds); RMSE_all = zeros(nm,n_seeds); params = zeros(nm,1);

            % Single-factor baselines see the tensor averaged over the other factor
            X_g1 = obj.fold_to_factor(obj.X_tensor,1);
            X_g2 = obj.fold_to_factor(obj.X_tensor,2);
            G_cyc = StarGAlgebra('cyclic',ng);
            inv_feats = @(X) [mean(X,3), std(X,0,3)];   % group-invariant summaries

            fprintf('\n============================================================\n');
            fprintf('  Product Group: Z_%d x Z_%d, %d mol, %d seeds\n',obj.n1,obj.n2,n,n_seeds);
            fprintf('============================================================\n');

            for seed = 1:n_seeds
                fprintf('\n--  Seed %d/%d --\n',seed,n_seeds);
                rng(seed*111); idx = randperm(n);
                ntr = round(sp(1)*n); nva = round(sp(2)*n);
                tri = idx(1:ntr); vai = idx(ntr+1:ntr+nva); tei = idx(ntr+nva+1:end);
                ytr = y(tri); yva = y(vai); yte = y(tei);
                Xtr = obj.X_tensor(tri,:,:); Xva = obj.X_tensor(vai,:,:); Xte = obj.X_tensor(tei,:,:);

                % 1. star_G features on G1 x G2 + ridge
                t0 = tic; fprintf('  [1] G1xG2 Ridge...');
                [yp, w, ftr, fva, fte] = obj.stargs_ridge(Xtr, ytr, Xva, yva, Xte, obj.G_prod);
                [R2_all(1,seed), RMSE_all(1,seed)] = obj.score(yte, yp);
                params(1) = numel(w);
                fprintf(' R2=%.4f (%.1fs)\n',R2_all(1,seed),toc(t0));

                % 2. The same G1 x G2 features + MLP
                t0 = tic; fprintf('  [2] G1xG2 MLP...');
                [Wm, Bm] = obj.train_mlp(ftr, ytr, fva, yva, [size(ftr,2),64,32,1], 300, 0.003);
                yp = obj.predict_mlp(fte, Wm, Bm);
                [R2_all(2,seed), RMSE_all(2,seed)] = obj.score(yte, yp);
                params(2) = obj.count_params(Wm, Bm);
                fprintf(' R2=%.4f (%.1fs)\n',R2_all(2,seed),toc(t0));

                % 3. G1 only
                t0 = tic; fprintf('  [3] G1 only (Z_%d)...',obj.n1);
                [yp, w] = obj.stargs_ridge(X_g1(tri,:,:), ytr, X_g1(vai,:,:), yva, X_g1(tei,:,:), obj.G1);
                [R2_all(3,seed), RMSE_all(3,seed)] = obj.score(yte, yp);
                params(3) = numel(w);
                fprintf(' R2=%.4f (%.1fs)\n',R2_all(3,seed),toc(t0));

                % 4. G2 only
                t0 = tic; fprintf('  [4] G2 only (Z_%d)...',obj.n2);
                [yp, w] = obj.stargs_ridge(X_g2(tri,:,:), ytr, X_g2(vai,:,:), yva, X_g2(tei,:,:), obj.G2);
                [R2_all(4,seed), RMSE_all(4,seed)] = obj.score(yte, yp);
                params(4) = numel(w);
                fprintf(' R2=%.4f (%.1fs)\n',R2_all(4,seed),toc(t0));

                % 5. Z_{n1*n2} cyclic (wrong structure)
                t0 = tic; fprintf('  [5] Z_%d cyclic...',ng);
                [yp, w] = obj.stargs_ridge(Xtr, ytr, Xva, yva, Xte, G_cyc);
                [R2_all(5,seed), RMSE_all(5,seed)] = obj.score(yte, yp);
                params(5) = numel(w);
                fprintf(' R2=%.4f (%.1fs)\n',R2_all(5,seed),toc(t0));

                % 6. Standard MLP on the identity-element features only
                t0 = tic; fprintf('  [6] Standard MLP...');
                [yp, params(6)] = obj.standardized_mlp(Xtr(:,:,1), ytr, Xva(:,:,1), yva, Xte(:,:,1), 300);
                [R2_all(6,seed), RMSE_all(6,seed)] = obj.score(yte, yp);
                fprintf(' R2=%.4f (%.1fs)\n',R2_all(6,seed),toc(t0));

                % 7. MLP on group mean/std (invariant) features
                t0 = tic; fprintf('  [7] Invariant MLP...');
                [yp, params(7)] = obj.standardized_mlp(inv_feats(Xtr), ytr, inv_feats(Xva), yva, inv_feats(Xte), 300);
                [R2_all(7,seed), RMSE_all(7,seed)] = obj.score(yte, yp);
                fprintf(' R2=%.4f (%.1fs)\n',R2_all(7,seed),toc(t0));

                % 8. MLP trained on every group element as an augmented sample
                t0 = tic; fprintf('  [8] Augmented MLP...');
                Xa  = reshape(permute(Xtr,[1,3,2]),[],nfa); ya  = repmat(ytr,ng,1);
                Xav = reshape(permute(Xva,[1,3,2]),[],nfa); yav = repmat(yva,ng,1);
                [yp, params(8)] = obj.standardized_mlp(Xa, ya, Xav, yav, Xte(:,:,1), 200);
                [R2_all(8,seed), RMSE_all(8,seed)] = obj.score(yte, yp);
                fprintf(' R2=%.4f (%.1fs)\n',R2_all(8,seed),toc(t0));
            end

            fprintf('\n============================================================\n');
            fprintf('  Product Group Results: Z_%d x Z_%d (%d seeds)\n',obj.n1,obj.n2,n_seeds);
            fprintf('============================================================\n');
            fprintf('%-22s %13s %15s %8s\n','Method','Test R2','RMSE','Params');
            fprintf('%s\n',repmat('-',1,61));
            for m = 1:nm
                fprintf('%-22s %5.3f+/-%5.3f %6.4f+/-%6.4f %8d\n', ...
                    mn{m},mean(R2_all(m,:)),std(R2_all(m,:)),mean(RMSE_all(m,:)),std(RMSE_all(m,:)),params(m));
            end
            results.method_names = mn; results.R2 = R2_all; results.RMSE = RMSE_all; results.params = params;
            obj.results = results;
        end

        %% FACTORIZATION DISCOVERY ========================================
        function report = discover_factorization(obj, target_col)
            % Score every factorization Z_a x Z_b (2 <= a <= b, a*b = n1*n2)
            % plus the flat cyclic group Z_(n1*n2). Each candidate gets star_G
            % features + ridge on a single split seeded by rng(0): 50% train,
            % 20% val, remainder test. The method prints the best by test R2.
            % Returns a struct array with fields name and r2. r2 is NaN when a
            % candidate throws, and a warning reports the error.
            % NOTE(review): every candidate reads the same X_tensor, whose group
            % axis is laid out as (g2-1)*n1+g1 for the true Z_n1 x Z_n2.
            y = obj.properties_mat(:,target_col); ng = obj.n1*obj.n2;
            fprintf('\n============================================================\n');
            fprintf('  Factorization Discovery: n=%d\n',ng);
            fprintf('============================================================\n');

            a = 2:floor(sqrt(ng)); a = a(mod(ng,a)==0);   % a <= b  <=>  a <= sqrt(ng)
            factors = [a(:), ng./a(:)];

            n = obj.n_molecules; rng(0); idx = randperm(n);
            ntr = round(0.5*n); nva = round(0.2*n);
            tri = idx(1:ntr); vai = idx(ntr+1:ntr+nva); tei = idx(ntr+nva+1:end);
            ytr = y(tri); yva = y(vai); yte = y(tei);
            Xtr = obj.X_tensor(tri,:,:); Xva = obj.X_tensor(vai,:,:); Xte = obj.X_tensor(tei,:,:);

            % Candidate names and lazy group constructors (built inside the try)
            nc = size(factors,1);
            names = cell(nc+1,1); makers = cell(nc+1,1);
            for f = 1:nc
                fa = factors(f,1); fb = factors(f,2);
                names{f} = sprintf('Z_%d x Z_%d',fa,fb);
                makers{f} = @() StarGAlgebra('product',{StarGAlgebra('cyclic',fa),StarGAlgebra('cyclic',fb)});
            end
            names{end} = sprintf('Z_%d (cyclic)',ng);
            makers{end} = @() StarGAlgebra('cyclic',ng);

            report = struct('name',{},'r2',{});
            for k = 1:numel(names)
                try
                    yp = obj.stargs_ridge(Xtr, ytr, Xva, yva, Xte, makers{k}());
                    r2v = obj.r2(yte, yp);
                catch err
                    warning('product_group_experiment:candidateFailed', ...
                        '%s failed: %s', names{k}, err.message);
                    r2v = NaN;
                end
                report(end+1) = struct('name',names{k},'r2',r2v); %#ok<AGROW>
                fprintf('  %-15s R2=%+.4f\n',names{k},r2v);
            end

            [~,bi] = max([report.r2]);   % max ignores NaN
            fprintf('\n  >>> Best: %s (R2=%.4f)\n',report(bi).name,report(bi).r2);
        end

        %% ML UTILITIES ===================================================
        function [w,bl]=ridge_cv(~,Xr,yr,Xv,yv)
            % Closed-form ridge regression. lambda is picked from a fixed grid
            % by validation MSE; bl is the chosen lambda (0.01 if none improves
            % on Inf). The first coefficient is unpenalized (R(1,1)=0).
            % NOTE(review): presumably column 1 is an intercept -- confirm
            % against extractStarGFeatures.
            lams = [1e-3,0.01,0.1,1,10,100,1e3];
            R = eye(size(Xr,2)); R(1,1) = 0;
            XtX = Xr'*Xr; Xty = Xr'*yr;      % loop-invariant normal-equation terms
            best_err = Inf; bl = 0.01;
            for lam = lams
                wt = (XtX + lam*R) \ Xty;
                e = mean((yv - Xv*wt).^2);
                if e < best_err, best_err = e; bl = lam; end
            end
            w = (XtX + bl*R) \ Xty;
        end

        function [W,B]=train_mlp(obj,X,y,Xv,yv,layers,maxep,lr)
            % Train a ReLU MLP for scalar regression with minibatch Adam.
            %   X, y   : training inputs (n x layers(1)) and targets
            %   Xv, yv : validation set used for early stopping
            %   layers : layer widths, e.g. [d 64 32 1]; the last layer is linear
            %   maxep  : maximum number of epochs; lr : Adam step size
            % Weights start as randn*sqrt(2/fan_in), so results depend on the
            % global RNG. The output gradient (pred-y)/batch is that of 0.5*MSE.
            % Returns the weights with the lowest validation MSE and stops after
            % 20 epochs without improvement.
            y = y(:); yv = yv(:);   % column targets avoid accidental broadcasting
            nl = numel(layers) - 1;
            W = cell(nl,1); B = cell(nl,1);
            for l = 1:nl
                fan_in = layers(l);
                W{l} = randn(fan_in, layers(l+1)) * sqrt(2/fan_in);
                B{l} = zeros(1, layers(l+1));
            end
            mW = cellfun(@(m) zeros(size(m)), W, 'UniformOutput', false); vW = mW;
            mB = cellfun(@(b) zeros(size(b)), B, 'UniformOutput', false); vB = mB;
            beta1 = 0.9; beta2 = 0.999; adam_eps = 1e-8;
            patience = 20; stale = 0; best_val = Inf; W_best = W; B_best = B;
            n = size(X,1); bs = min(256, n);
            step = 0;   % Adam time step: advanced once per minibatch update
            for ep = 1:maxep
                order = randperm(n);
                for s = 1:bs:n
                    bi = order(s:min(s+bs-1, n));
                    A = cell(nl+1,1); A{1} = X(bi,:);
                    for l = 1:nl
                        Zl = A{l}*W{l} + B{l};
                        if l < nl, A{l+1} = max(0, Zl); else, A{l+1} = Zl; end
                    end
                    dZ = (A{nl+1} - y(bi)) / numel(bi);
                    step = step + 1;
                    c1 = 1 - beta1^step; c2 = 1 - beta2^step;   % bias correction
                    for l = nl:-1:1
                        gW = A{l}'*dZ; gB = sum(dZ,1);
                        % back-propagate through the pre-update weights
                        if l > 1, dZ = (dZ*W{l}') .* (A{l} > 0); end
                        mW{l} = beta1*mW{l} + (1-beta1)*gW;  vW{l} = beta2*vW{l} + (1-beta2)*gW.^2;
                        mB{l} = beta1*mB{l} + (1-beta1)*gB;  vB{l} = beta2*vB{l} + (1-beta2)*gB.^2;
                        W{l} = W{l} - lr*(mW{l}/c1) ./ (sqrt(vW{l}/c2) + adam_eps);
                        B{l} = B{l} - lr*(mB{l}/c1) ./ (sqrt(vB{l}/c2) + adam_eps);
                    end
                end
                val_mse = mean((yv - obj.predict_mlp(Xv, W, B)).^2);
                if val_mse < best_val
                    best_val = val_mse; stale = 0; W_best = W; B_best = B;
                else
                    stale = stale + 1;
                    if stale >= patience, break; end
                end
            end
            W = W_best; B = B_best;
        end

        function yp=predict_mlp(~,X,W,B)
            % Forward pass of the MLP from train_mlp: ReLU on every layer but the last.
            yp = X;
            nl = numel(W);
            for l = 1:nl
                yp = yp*W{l} + B{l};
                if l < nl, yp = max(0, yp); end
            end
        end

        function v=r2(~,yt,yp)
            % Coefficient of determination 1 - SSE/SST; 1e-20 guards SST = 0.
            % Returns NaN for an empty target vector instead of a spurious 1.
            if isempty(yt), v = NaN; return; end
            v = 1 - sum((yt-yp).^2)/(sum((yt-mean(yt)).^2)+1e-20);
        end
    end

    methods (Access = private)
        function [ang1, ang2] = group_angles(obj)
            % Angles 2*pi*k/n for k = 0..n-1 on each factor (row vectors).
            ang1 = (0:obj.n1-1)*2*pi/obj.n1;
            ang2 = (0:obj.n2-1)*2*pi/obj.n2;
        end

        function [ok, pos, Z, pr] = parse_xyz(~, fp, em)
            % Parse one .xyz file. Assumed layout: an atom-count line; a
            % property line whose tokens 3.. hold up to 15 numbers; then one
            % "symbol x y z ..." line per atom. Mathematica exponents ("*^") are
            % read as "e". ok is false when the file cannot be opened, is
            % truncated, or has non-numeric coordinates.
            ok = false; pos = []; Z = []; pr = [];
            fid = fopen(fp, 'r');
            if fid < 0, return; end
            closer = onCleanup(@() fclose(fid)); %#ok<NASGU>  closes on every exit path
            try
                na = str2double(fgetl(fid));
                if isnan(na) || na < 1, return; end
                pl = fgetl(fid);
                if ~ischar(pl), return; end
                tk = strsplit(strtrim(strrep(pl,'*^','e')));
                pr = zeros(1, min(15, numel(tk)-2));
                for t = 1:numel(pr)
                    v = str2double(tk{t+2});
                    if ~isnan(v), pr(t) = v; end
                end
                Z = zeros(na,1); pos = zeros(na,3);
                for a = 1:na
                    ln = fgetl(fid);
                    if ~ischar(ln), return; end                     % hit end of file
                    pts = strsplit(strtrim(strrep(ln, char(9), ' ')));
                    if numel(pts) < 4, return; end                  % truncated atom line
                    if isKey(em, pts{1}), Z(a) = em(pts{1}); else, Z(a) = 6; end
                    pos(a,:) = str2double(strrep(pts(2:4), '*^', 'e'));
                end
                ok = ~any(isnan(pos(:)));
            catch
                ok = false;   % malformed content (e.g. non-integer atom count)
            end
        end

        function [yp, w, ftr, fva, fte] = stargs_ridge(obj, Xtr, ytr, Xva, yva, Xte, G)
            % Extract star_G features for group G and fit ridge_cv. Val and test
            % features reuse the normalization parameters returned for the
            % training split. Returns test predictions, the weights, and the
            % three feature matrices.
            [ftr, np] = extractStarGFeatures(Xtr, G, obj.n_feat_per_rot);
            fva = extractStarGFeatures(Xva, G, obj.n_feat_per_rot, np);
            fte = extractStarGFeatures(Xte, G, obj.n_feat_per_rot, np);
            [w, ~] = obj.ridge_cv(ftr, ytr, fva, yva);
            yp = fte * w;
        end

        function [yp, nparam] = standardized_mlp(obj, Xtr, ytr, Xva, yva, Xte, maxep)
            % z-score the inputs with training statistics, train a [d 64 32 1]
            % MLP (lr 0.003), and predict on Xte. Also returns the parameter count.
            mu = mean(Xtr,1); sd = std(Xtr,0,1) + 1e-8;
            [W, B] = obj.train_mlp((Xtr-mu)./sd, ytr, (Xva-mu)./sd, yva, ...
                                   [size(Xtr,2),64,32,1], maxep, 0.003);
            yp = obj.predict_mlp((Xte-mu)./sd, W, B);
            nparam = obj.count_params(W, B);
        end

        function [r2v, rmse] = score(obj, yt, yp)
            % Test R^2 and RMSE of predictions yp against targets yt.
            r2v = obj.r2(yt, yp);
            rmse = sqrt(mean((yt-yp).^2));
        end

        function n = count_params(~, W, B)
            % Total number of weights and biases in an MLP.
            n = sum(cellfun(@numel, W)) + sum(cellfun(@numel, B));
        end
    end
end

function s = tif(c, a, b)
% TIF  Inline conditional: returns a when condition c holds, otherwise b.
%   c follows MATLAB "if" truthiness (non-empty and all elements nonzero).
if c
    s = a;
    return
end
s = b;
end
