%% ========================================================================
%% QM9_experiment.m
%% Real-world validation on QM9 molecular property prediction
%%
%% LAPTOP-FRIENDLY: default 1000 molecules, timing printed per step.
%% LH & Claude 2026
%% ========================================================================

classdef QM9_experiment < handle
    properties
        data_dir; n_molecules; n_rotations; max_atoms
        coords; atomic_numbers; properties_mat
        G; X_tensor; n_feat_per_rot; results
        is_real_data  % true if loaded from .xyz files
    end

    methods
        function obj = QM9_experiment(data_dir, n_rotations, varargin)
            % QM9_EXPERIMENT  Create an experiment tied to a data folder and
            % a cyclic rotation group of order n_rotations.
            %
            %   data_dir     stored as given; load_data later looks for
            %                *.xyz files inside it
            %   n_rotations  group order handed to StarGAlgebra('cyclic', ...)
            %
            %   Name-value options:
            %     'max_atoms'    (default 29)  stored in obj.max_atoms
            %     'n_molecules'  (default Inf) accepted by the parser but not
            %                    read afterwards; obj.n_molecules starts at 0
            %                    and is set by the loading methods
            parser = inputParser;
            parser.addRequired('data_dir');
            parser.addRequired('n_rotations');
            parser.addParameter('max_atoms', 29);
            parser.addParameter('n_molecules', Inf);
            parser.parse(data_dir, n_rotations, varargin{:});
            opts = parser.Results;

            obj.data_dir     = opts.data_dir;
            obj.n_rotations  = opts.n_rotations;
            obj.max_atoms    = opts.max_atoms;
            obj.n_molecules  = 0;
            obj.is_real_data = false;

            % Group object used by the star_G feature extraction.
            obj.G = StarGAlgebra('cyclic', obj.n_rotations);
            fprintf('QM9 Experiment: |G| = %d (Z_%d)\n', obj.n_rotations, obj.n_rotations);
        end

        %% DATA LOADING ===================================================
        function obj = load_data(obj, n_max)
            % LOAD_DATA  Populate the molecule set.
            %
            % When obj.data_dir is an existing folder containing *.xyz
            % files they are parsed (at most n_max of them) and
            % obj.is_real_data is set to true. Otherwise a message is
            % printed and min(n_max, 1000) synthetic molecules are generated.
            %
            %   n_max  optional cap on the number of molecules (default Inf)
            if nargin < 2
                n_max = Inf;
            end

            found = [];
            if isfolder(obj.data_dir)
                found = dir(fullfile(obj.data_dir, '*.xyz'));
            end

            % Guard clause: nothing on disk -> synthetic fallback.
            if isempty(found)
                fprintf('No .xyz files in "%s". Generating synthetic molecules.\n', obj.data_dir);
                obj = obj.generate_synthetic(min(n_max, 1000));
                return;
            end

            obj = obj.load_xyz_files(found, n_max);
            obj.is_real_data = true;
        end

        function obj = load_xyz_files(obj, xyz_files, n_max)
            % LOAD_XYZ_FILES  Parse up to n_max .xyz files into the object.
            %
            %   xyz_files  struct array with .folder and .name fields (as
            %              returned by dir)
            %   n_max      maximum number of files to read (may be Inf)
            %
            % File layout as read below:
            %   line 1   atom count
            %   line 2   whitespace-separated tokens; the first two are
            %            skipped (presumably the QM9 tag and index - confirm)
            %            and up to 15 following ones are stored as properties
            %            (tokens that do not parse as numbers are left at 0)
            %   line 3+  one line per atom: symbol x y z [extra columns]
            % '*^' exponent notation is rewritten to 'e' before parsing.
            % Element symbols missing from the lookup table are mapped to 6.
            %
            % A file is skipped (and counted) when the atom count is missing,
            % non-integer or < 1, when an atom line is missing or has fewer
            % than 4 fields, when a coordinate fails to parse, or when reading
            % raises an error. Only accepted molecules end up in obj.coords,
            % obj.atomic_numbers, obj.properties_mat and obj.n_molecules.
            n_files = min(length(xyz_files), n_max);
            fprintf('Loading %d QM9 .xyz files...\n', n_files);
            t0 = tic;

            obj.coords = cell(n_files,1); obj.atomic_numbers = cell(n_files,1);
            obj.properties_mat = zeros(n_files, 15);
            elem_map = containers.Map({'H','C','N','O','F','S','Cl','Br','I','P','Si','B'}, ...
                {1,6,7,8,9,16,17,35,53,15,14,5});
            valid = 0; n_skipped = 0;
            for f = 1:n_files
                fpath = fullfile(xyz_files(f).folder, xyz_files(f).name);
                fid = -1; ok = false;
                Z = []; pos = []; props = [];
                try
                    fid = fopen(fpath,'r');
                    % fgetl on a failed fopen (fid == -1) raises and lands in catch.
                    n_atoms = str2double(fgetl(fid));
                    if ~isnan(n_atoms) && n_atoms >= 1 && n_atoms == floor(n_atoms)
                        prop_line = strrep(fgetl(fid),'*^','e');
                        tokens = strsplit(strtrim(prop_line));
                        props = zeros(1,min(15,length(tokens)-2));
                        for t = 1:length(props)
                            val = str2double(tokens{t+2});
                            if ~isnan(val), props(t) = val; end
                        end

                        Z = zeros(n_atoms,1); pos = zeros(n_atoms,3);
                        complete = true;
                        for a = 1:n_atoms
                            line = fgetl(fid);
                            % fgetl returns -1 (non-char) at end of file:
                            % the file is shorter than its declared atom count.
                            if ~ischar(line), complete = false; break; end
                            % QM9 uses tabs and may have extra columns
                            line = strrep(line, char(9), ' ');
                            parts = strsplit(strtrim(line));
                            if length(parts) < 4, complete = false; break; end
                            elem = parts{1};
                            if isKey(elem_map,elem), Z(a) = elem_map(elem); else, Z(a) = 6; end
                            pos(a,:) = [str2double(strrep(parts{2},'*^','e')), ...
                                        str2double(strrep(parts{3},'*^','e')), ...
                                        str2double(strrep(parts{4},'*^','e'))];
                        end
                        % Reject truncated files instead of keeping the
                        % zero-filled rows of the atoms that were never read.
                        ok = complete && ~any(isnan(pos(:)));
                    end
                catch
                    ok = false;
                end
                % Single close point: the handle is released exactly once
                % whether parsing succeeded, bailed out, or raised.
                if fid > 0, fclose(fid); end

                if ok
                    valid = valid+1;
                    obj.coords{valid} = pos;
                    obj.atomic_numbers{valid} = Z;
                    obj.properties_mat(valid,1:length(props)) = props;
                else
                    n_skipped = n_skipped+1;
                end
                if mod(f,2000)==0, fprintf('  %d / %d (%.1fs)\n', f, n_files, toc(t0)); end
            end
            % Drop the unused tail left by skipped files.
            obj.coords = obj.coords(1:valid); obj.atomic_numbers = obj.atomic_numbers(1:valid);
            obj.properties_mat = obj.properties_mat(1:valid,:); obj.n_molecules = valid;
            fprintf('Loaded %d molecules in %.1fs.\n', valid, toc(t0));
            if n_skipped > 0
                fprintf('Skipped %d unreadable or incomplete files.\n', n_skipped);
            end
        end

        function obj = generate_synthetic(obj, n_mol)
            % GENERATE_SYNTHETIC  Create n_mol random molecules with a
            % synthetic regression target.
            %
            % Each molecule has 4..12 atoms with Gaussian coordinates
            % (scaled by 1.5, centred) and atomic numbers drawn from
            % {1,6,7,8,9}. The target is written to column 8 of
            % obj.properties_mat (all other columns stay 0):
            %   mean pairwise distance
            %   + 0.003*P(2) + 0.002*P(3)   where P is the power spectrum of
            %                               the Z-weighted mean projection
            %                               over the n_rotations angles
            %   + sum(Z.^2)/500
            %
            % Side effect: reseeds the global RNG with rng(42,'twister').
            fprintf('Generating %d synthetic molecules...\n', n_mol);
            rng(42,'twister'); atom_types = [1,6,7,8,9];
            obj.coords = cell(n_mol,1); obj.atomic_numbers = cell(n_mol,1);
            obj.properties_mat = zeros(n_mol, 15);
            n_rot = obj.n_rotations;
            angles = (0:n_rot-1)*2*pi/n_rot;
            e1 = [cos(angles); sin(angles); zeros(1,n_rot)];
            for i = 1:n_mol
                na = randi([4,12]); pos = randn(na,3)*1.5; pos = pos-mean(pos,1);
                Z = atom_types(randi(5,na,1))';
                obj.coords{i} = pos; obj.atomic_numbers{i} = Z;
                % Nonzero entries of the full distance matrix (each pair
                % appears twice, which leaves the mean unchanged).
                D = squareform(pdist(pos)); dd = D(D>0);
                w = Z/(sum(Z)+1e-10);
                pw = abs(fft(w'*pos*e1)).^2;
                % With fewer than 3 rotations the spectrum has no 2nd/3rd
                % bin; treat the missing bins as zero power instead of
                % indexing past the end. No-op when n_rot >= 3.
                pw(end+1:3) = 0;
                obj.properties_mat(i,8) = mean(dd) + 0.003*pw(2) + 0.002*pw(3) + sum(Z.^2)/500;
            end
            obj.n_molecules = n_mol;
            fprintf('Target = d_mean + angular_freq_power + Z_term\n');
        end

        %% ANGULAR PROJECTION FEATURES ====================================
        function obj = compute_rotated_features(obj, varargin)
            % COMPUTE_ROTATED_FEATURES  Build the feature tensor
            % obj.X_tensor of size n_molecules x n_feat x n_rotations by
            % evaluating angular_features for every molecule at the angles
            % 2*pi*k/n_rotations, k = 0..n_rotations-1.
            %
            %   Name-value option:
            %     'n_feat' (default 48)  number of feature rows requested
            %                            from angular_features
            %
            % Also stores the row count in obj.n_feat_per_rot and runs
            % verify_cyclic_structure at the end.
            % Raises QM9_experiment:noData when no molecules are loaded.
            p = inputParser; addParameter(p,'n_feat',48); parse(p,varargin{:});
            n_feat_target = p.Results.n_feat;
            n_mol = obj.n_molecules; n_rot = obj.n_rotations;
            if isempty(n_mol) || n_mol < 1 || isempty(obj.coords)
                % Fail with a clear message rather than an index error on coords{1}.
                error('QM9_experiment:noData', ...
                    'No molecules loaded; call load_data before compute_rotated_features.');
            end
            angles = (0:n_rot-1)*2*pi/n_rot;

            % Probe the first molecule to learn the actual number of feature rows.
            feat0 = obj.angular_features(obj.coords{1}, obj.atomic_numbers{1}, angles, n_feat_target);
            n_feat = size(feat0,1); obj.n_feat_per_rot = n_feat;
            X = zeros(n_mol, n_feat, n_rot);
            t0 = tic;
            fprintf('Computing features: %d molecules x %d rotations...\n', n_mol, n_rot);
            for mol = 1:n_mol
                pos = obj.coords{mol} - mean(obj.coords{mol},1);
                X(mol,:,:) = obj.angular_features(pos, obj.atomic_numbers{mol}, angles, n_feat_target);
                if mod(mol,500)==0, fprintf('  %d/%d (%.1fs)\n', mol, n_mol, toc(t0)); end
            end
            obj.X_tensor = X;
            % Query each dimension explicitly: size(X) drops a trailing
            % singleton when n_rot == 1 and would shift the format arguments.
            fprintf('Feature tensor: %d x %d x %d (%.1fs)\n', ...
                size(X,1), size(X,2), size(X,3), toc(t0));
            obj.verify_cyclic_structure();
        end

        function F = angular_features(~, coords, Z, angles, n_feat_target)
            % ANGULAR_FEATURES  Feature matrix of one molecule, one column
            % per angle.
            %
            %   coords         n_atoms x 3 coordinates (centred internally)
            %   Z              atomic numbers, one per atom
            %   angles         1 x n_rot in-plane angles (radians)
            %   n_feat_target  number of rows of the result
            %   F              n_feat_target x n_rot
            %
            % Coordinates are projected on e1 = (cos a, sin a, 0),
            % e2 = (-sin a, cos a, 0) and the z axis for every angle a.
            % Row layout before padding/truncation (48 rows, fixed for any
            % atom count):
            %    1-19  Z-weighted and raw-Z moments of the projections
            %   20-31  projections (p1,p2,pz) of the 4 atoms with largest Z
            %   32-37  Z_i*Z_j/d_ij-scaled projection differences between
            %          the largest-Z atom and the next 3
            %   38-41  mean/std/min/max pairwise distance
            %   42-45  4 largest distances from the centroid
            %   46-48  sum(Z.^2)/100, mean(Z), atom count
            % Slots for atoms or pairs that do not exist (molecules with
            % fewer than 4 atoms) are filled with zero rows so that a given
            % row index carries the same feature for every molecule, in the
            % same way the radius block pads its missing entries.
            % The result is zero-padded or truncated to n_feat_target rows.
            n_rot = length(angles); n_atoms = size(coords,1);
            Zn = Z(:); w = Zn/(sum(Zn)+1e-10);
            coords_c = coords - mean(coords, 1);
            zero_row = zeros(1,n_rot);
            feat_list = {};

            e1 = [cos(angles); sin(angles); zeros(1,n_rot)];
            e2 = [-sin(angles); cos(angles); zeros(1,n_rot)];

            p1 = coords_c*e1; p2 = coords_c*e2;
            pz = coords_c*repmat([0;0;1],1,n_rot);

            % Equivariant moments
            feat_list{end+1}=w'*p1; feat_list{end+1}=w'*p2; feat_list{end+1}=w'*pz;
            feat_list{end+1}=w'*(p1.^2); feat_list{end+1}=w'*(p2.^2); feat_list{end+1}=w'*(pz.^2);
            feat_list{end+1}=w'*(p1.*p2); feat_list{end+1}=w'*(p1.*pz); feat_list{end+1}=w'*(p2.*pz);
            feat_list{end+1}=Zn'*p1; feat_list{end+1}=Zn'*p2;
            feat_list{end+1}=Zn'*(p1.^2); feat_list{end+1}=Zn'*(p2.^2);
            feat_list{end+1}=Zn'*(p1.*p2); feat_list{end+1}=Zn'*(p1.*pz);
            feat_list{end+1}=w'*(p1.^3); feat_list{end+1}=w'*(p2.^3);
            feat_list{end+1}=Zn'*(p1.^3); feat_list{end+1}=Zn'*(p2.^3);

            % Projections of the heaviest atoms (by atomic number).
            [~,si] = sort(Zn,'descend');
            n_top = min(4,n_atoms);
            for k = 1:n_top
                idx = si(k);
                feat_list{end+1}=p1(idx,:); feat_list{end+1}=p2(idx,:); feat_list{end+1}=pz(idx,:);
            end
            for k = n_top+1:4   % keep the layout fixed for small molecules
                feat_list{end+1}=zero_row; feat_list{end+1}=zero_row; feat_list{end+1}=zero_row;
            end

            % Pair features between the heaviest atom and the next ones.
            n_pairs = max(0, min(3,n_atoms-1));
            for pp = 1:n_pairs
                ii = si(1); jj = si(pp+1);
                d_ij = norm(coords_c(ii,:)-coords_c(jj,:))+1e-8;
                feat_list{end+1}=Zn(ii)*Zn(jj)/d_ij*(p1(ii,:)-p1(jj,:));
                feat_list{end+1}=Zn(ii)*Zn(jj)/d_ij*(p2(ii,:)-p2(jj,:));
            end
            for pp = n_pairs+1:3   % keep the layout fixed for small molecules
                feat_list{end+1}=zero_row; feat_list{end+1}=zero_row;
            end

            % Invariant features (constant across angles)
            if n_atoms>=2
                D = pdist(coords_c);
                for v=[mean(D),std(D),min(D),max(D)], feat_list{end+1}=repmat(v,1,n_rot); end
            else
                for k=1:4, feat_list{end+1}=zero_row; end
            end
            r = sqrt(sum(coords_c.^2,2)); rs = sort(r,'descend');
            for k=1:min(4,n_atoms), feat_list{end+1}=repmat(rs(k),1,n_rot); end
            for k=n_atoms+1:4, feat_list{end+1}=zero_row; end
            feat_list{end+1}=repmat(sum(Zn.^2)/100,1,n_rot);
            feat_list{end+1}=repmat(mean(Zn),1,n_rot);
            feat_list{end+1}=repmat(n_atoms,1,n_rot);

            F = cell2mat(feat_list');
            nr = size(F,1);
            if nr<n_feat_target, F=[F;zeros(n_feat_target-nr,n_rot)];
            elseif nr>n_feat_target, F=F(1:n_feat_target,:); end
        end

        function verify_cyclic_structure(obj)
            % VERIFY_CYCLIC_STRUCTURE  Sanity check printed to the console:
            % rotating one molecule about z by 2*pi/n_rotations should
            % cyclically shift the columns of its feature matrix by one.
            % Uses molecule min(5, n_molecules) and reports the relative
            % error; does nothing when no molecules are loaded.
            if obj.n_molecules < 1
                return;
            end

            probe    = min(5, obj.n_molecules);
            centred  = obj.coords{probe} - mean(obj.coords{probe},1);
            charges  = obj.atomic_numbers{probe};
            n_rot    = obj.n_rotations;
            grid_ang = (0:n_rot-1)*2*pi/n_rot;
            n_rows   = obj.n_feat_per_rot;

            % Features of the original and of the once-rotated molecule.
            F_ref = obj.angular_features(centred, charges, grid_ang, n_rows);
            step  = 2*pi/n_rot;
            Rz    = [cos(step),-sin(step),0;sin(step),cos(step),0;0,0,1];
            F_rot = obj.angular_features((Rz*centred')', charges, grid_ang, n_rows);

            % Expected result: reference features shifted by one column.
            F_shift = circshift(F_ref,1,2);
            err = norm(F_rot(:)-F_shift(:))/(norm(F_ref(:))+1e-20);

            fprintf('Cyclic check: %.2e', err);
            if err < 1e-10
                fprintf(' PASS\n\n');
            else
                fprintf(' (%.2e)\n\n', err);
            end
        end

        %% COMPARISON =====================================================
        function results = run_comparison(obj, target_col, varargin)
            % RUN_COMPARISON  Train and evaluate five regressors on one
            % property column over several random splits.
            %
            %   target_col  column of obj.properties_mat used as target
            %   Name-value options:
            %     'n_seeds' (default 3)               number of random splits
            %     'split'   (default [0.7 0.15 0.15]) train/validation
            %               fractions; only the first two entries are read,
            %               the remaining molecules form the test set
            %
            % Methods compared (rows of the result arrays, in this order):
            %   1 star_G features + ridge regression (lambda picked on validation)
            %   2 MLP on the same star_G features
            %   3 MLP on the standardized features of rotation 1 only
            %   4 MLP on mean/std/min/max of the features over rotations
            %   5 MLP trained on all rotations stacked as extra samples,
            %     evaluated on rotation 1
            %
            % Returns (and stores in obj.results) a struct with fields
            % method_names, R2, RMSE, rot_var (each n_methods x n_seeds) and
            % params (n_methods x 1 parameter counts).
            %
            % Raises QM9_experiment:noFeatures, :badTarget or :badSplit on
            % invalid state/arguments instead of failing later with index
            % errors or NaN metrics.
            p=inputParser; addParameter(p,'n_seeds',3); addParameter(p,'split',[0.7,0.15,0.15]);
            parse(p, varargin{:});
            n_seeds=p.Results.n_seeds; sp=p.Results.split;
            n=obj.n_molecules;

            if isempty(obj.X_tensor) || size(obj.X_tensor,1) ~= n
                error('QM9_experiment:noFeatures', ...
                    'Feature tensor missing or stale; call compute_rotated_features first.');
            end
            if ~isscalar(target_col) || target_col ~= floor(target_col) || ...
                    target_col < 1 || target_col > size(obj.properties_mat,2)
                error('QM9_experiment:badTarget', ...
                    'target_col must be an integer in 1..%d.', size(obj.properties_mat,2));
            end
            if numel(sp) < 2
                error('QM9_experiment:badSplit', ...
                    'split needs at least the train and validation fractions.');
            end
            % Split sizes depend only on n, so compute and check them once.
            ntr=round(sp(1)*n); nva=round(sp(2)*n);
            if ntr < 1 || nva < 1 || ntr+nva >= n
                error('QM9_experiment:badSplit', ...
                    'split leaves an empty train/validation/test set for %d molecules.', n);
            end
            y=obj.properties_mat(:,target_col);

            mn={'starG_SVD_Ridge','Neural_starG','Standard_MLP','Invariant_MLP','Augmented_MLP'};
            nm=length(mn);
            R2_all=zeros(nm,n_seeds); RMSE_all=zeros(nm,n_seeds);
            rot_var=zeros(nm,n_seeds); params=zeros(nm,1);

            fprintf('\n============================================================\n');
            fprintf('  Comparison: col=%d, %d mol, %d seeds\n', target_col, n, n_seeds);
            fprintf('============================================================\n');

            for seed=1:n_seeds
                fprintf('\n,  Seed %d/%d , \n', seed, n_seeds);
                % The seed fixes both the split and every MLP initialisation
                % that follows, so the order of the steps below matters.
                rng(seed*111,'twister'); idx=randperm(n);
                tri=idx(1:ntr); vai=idx(ntr+1:ntr+nva); tei=idx(ntr+nva+1:end);
                Xtr=obj.X_tensor(tri,:,:); Xva=obj.X_tensor(vai,:,:); Xte=obj.X_tensor(tei,:,:);
                ytr=y(tri); yva=y(vai); yte=y(tei);

                % 1. star_G-SVD + Ridge
                % np is the second output of the training-set extraction and
                % is passed back in for validation/test (presumably fitted
                % normalisation parameters - confirm in extractStarGFeatures).
                t0=tic; fprintf('  [1] star_G-SVD + Ridge...');
                [ftr,np]=extractStarGFeatures(Xtr,obj.G,obj.n_feat_per_rot);
                fva=extractStarGFeatures(Xva,obj.G,obj.n_feat_per_rot,np);
                fte=extractStarGFeatures(Xte,obj.G,obj.n_feat_per_rot,np);
                [w1,~]=obj.ridge_cv(ftr,ytr,fva,yva); yp1=fte*w1;
                R2_all(1,seed)=obj.r2(yte,yp1); RMSE_all(1,seed)=sqrt(mean((yte-yp1).^2));
                rot_var(1,seed)=obj.measure_rotation_variance(Xte,np,w1);
                params(1)=numel(w1);
                fprintf(' R2=%.4f rv=%.1e (%.1fs)\n', R2_all(1,seed), rot_var(1,seed), toc(t0));

                % 2. Neural star_G (MLP on the features from step 1)
                t0=tic; fprintf('  [2] Neural star_G...');
                nin=size(ftr,2); [wm,bm]=obj.train_mlp(ftr,ytr,fva,yva,[nin,64,32,1],300,0.003);
                yp2=obj.predict_mlp(fte,wm,bm);
                R2_all(2,seed)=obj.r2(yte,yp2); RMSE_all(2,seed)=sqrt(mean((yte-yp2).^2));
                rot_var(2,seed)=obj.measure_rotation_variance(Xte,np,[],wm,bm);
                params(2)=sum(cellfun(@numel,wm))+sum(cellfun(@numel,bm));
                fprintf(' R2=%.4f rv=%.1e (%.1fs)\n', R2_all(2,seed), rot_var(2,seed), toc(t0));

                % 3. Standard MLP (rotation 1 only, standardized with training statistics)
                t0=tic; fprintf('  [3] Standard MLP...');
                Xs1=squeeze(Xtr(:,:,1)); Xsv=squeeze(Xva(:,:,1)); Xst=squeeze(Xte(:,:,1));
                [mu3,s3]=deal(mean(Xs1),std(Xs1)+1e-8);
                [w3,b3]=obj.train_mlp((Xs1-mu3)./s3,ytr,(Xsv-mu3)./s3,yva,[size(Xs1,2),64,32,1],300,0.003);
                yp3=obj.predict_mlp((Xst-mu3)./s3,w3,b3);
                R2_all(3,seed)=obj.r2(yte,yp3); RMSE_all(3,seed)=sqrt(mean((yte-yp3).^2));
                rot_var(3,seed)=obj.measure_rv_raw(Xte,mu3,s3,w3,b3);
                params(3)=sum(cellfun(@numel,w3))+sum(cellfun(@numel,b3));
                fprintf(' R2=%.4f rv=%.1e (%.1fs)\n', R2_all(3,seed), rot_var(3,seed), toc(t0));

                % 4. Invariant MLP (statistics over the rotation axis; its
                % rotation variance is recorded as 0 rather than measured)
                t0=tic; fprintf('  [4] Invariant MLP...');
                fi_tr=[mean(Xtr,3),std(Xtr,0,3),min(Xtr,[],3),max(Xtr,[],3)];
                fi_va=[mean(Xva,3),std(Xva,0,3),min(Xva,[],3),max(Xva,[],3)];
                fi_te=[mean(Xte,3),std(Xte,0,3),min(Xte,[],3),max(Xte,[],3)];
                [mu4,s4]=deal(mean(fi_tr),std(fi_tr)+1e-8);
                fi_tr=(fi_tr-mu4)./s4; fi_va=(fi_va-mu4)./s4; fi_te=(fi_te-mu4)./s4;
                [w4,b4]=obj.train_mlp(fi_tr,ytr,fi_va,yva,[size(fi_tr,2),64,32,1],300,0.003);
                yp4=obj.predict_mlp(fi_te,w4,b4);
                R2_all(4,seed)=obj.r2(yte,yp4); RMSE_all(4,seed)=sqrt(mean((yte-yp4).^2));
                rot_var(4,seed)=0; params(4)=sum(cellfun(@numel,w4))+sum(cellfun(@numel,b4));
                fprintf(' R2=%.4f (%.1fs)\n', R2_all(4,seed), toc(t0));

                % 5. Augmented MLP (every rotation becomes an extra sample;
                % rows are ordered rotation-major, matching repmat of the targets)
                t0=tic; fprintf('  [5] Augmented MLP...');
                nf=obj.n_feat_per_rot;
                Xa5=reshape(permute(Xtr,[1,3,2]),[],nf); ya5=repmat(ytr,obj.n_rotations,1);
                Xav=reshape(permute(Xva,[1,3,2]),[],nf); yav=repmat(yva,obj.n_rotations,1);
                [mu5,s5]=deal(mean(Xa5),std(Xa5)+1e-8);
                [w5,b5]=obj.train_mlp((Xa5-mu5)./s5,ya5,(Xav-mu5)./s5,yav,[nf,64,32,1],300,0.003);
                yp5=obj.predict_mlp((squeeze(Xte(:,:,1))-mu5)./s5,w5,b5);
                R2_all(5,seed)=obj.r2(yte,yp5); RMSE_all(5,seed)=sqrt(mean((yte-yp5).^2));
                rot_var(5,seed)=obj.measure_rv_raw(Xte,mu5,s5,w5,b5);
                params(5)=sum(cellfun(@numel,w5))+sum(cellfun(@numel,b5));
                fprintf(' R2=%.4f rv=%.1e (%.1fs)\n', R2_all(5,seed), rot_var(5,seed), toc(t0));
            end

            fprintf('\n============================================================\n');
            fprintf('  RESULTS (mean +/- std, %d seeds)\n', n_seeds);
            fprintf('============================================================\n');
            fprintf('%-22s %12s %12s %12s %8s\n','Method','Test R2','RMSE','Rot Var','Params');
            fprintf('%s\n', repmat('-',1,70));
            for m=1:nm
                rv=rot_var(m,:);
                if m==4, rvs='~0 (exact)'; else, rvs=sprintf('%.2e',mean(rv)); end
                fprintf('%-22s %5.3f+/-%5.3f %5.3f+/-%5.3f %12s %8d\n', ...
                    mn{m},mean(R2_all(m,:)),std(R2_all(m,:)),mean(RMSE_all(m,:)),std(RMSE_all(m,:)),rvs,params(m));
            end
            results.method_names=mn; results.R2=R2_all; results.RMSE=RMSE_all;
            results.rot_var=rot_var; results.params=params;
            obj.results=results;
        end

        %% ROTATION VARIANCE ==============================================
        function rv = measure_rotation_variance(obj,Xt,np,wr,wm,bm)
            % MEASURE_ROTATION_VARIANCE  Mean (over samples) of the variance
            % of the prediction across all cyclic shifts of the rotation axis.
            %
            %   Xt      feature tensor; only the first min(50, rows) samples
            %           are used (kept small for speed)
            %   np      forwarded unchanged to extractStarGFeatures
            %   wr      linear weights, used when wm is empty
            %   wm, bm  optional MLP weights/biases; when wm is non-empty
            %           predictions come from predict_mlp instead of wr
            if nargin < 6
                bm = [];
                if nargin < 5
                    wm = [];
                end
            end

            n_eval  = min(50, size(Xt,1));
            Xsub    = Xt(1:n_eval,:,:);
            n_rot   = obj.n_rotations;
            preds   = zeros(n_eval, n_rot);
            use_mlp = ~isempty(wm);

            for shift = 0:n_rot-1
                % Shift along dimension 3 (the rotation axis) and re-extract.
                shifted = circshift(Xsub, shift, 3);
                feats   = extractStarGFeatures(shifted, obj.G, obj.n_feat_per_rot, np);
                if use_mlp
                    preds(:,shift+1) = obj.predict_mlp(feats, wm, bm);
                else
                    preds(:,shift+1) = feats*wr;
                end
            end
            rv = mean(var(preds, 0, 2));
        end

        function rv = measure_rv_raw(obj,Xt,mu,sig,W,B)
            % MEASURE_RV_RAW  Rotation variance of an MLP that consumes the
            % features of a single rotation: predict from every rotation
            % slice of the first min(50, rows) samples, then average the
            % per-sample variance across rotations.
            %
            %   mu, sig  standardization offsets/scales applied to each slice
            %   W, B     MLP weights and biases for predict_mlp
            n_eval = min(50, size(Xt,1));
            n_rot  = obj.n_rotations;
            preds  = zeros(n_eval, n_rot);
            for rot = 1:n_rot
                % A single-page slice is already a 2-D matrix.
                slice = Xt(1:n_eval,:,rot);
                standardized = (slice - mu)./sig;
                preds(:,rot) = obj.predict_mlp(standardized, W, B);
            end
            per_sample = var(preds, 0, 2);
            rv = mean(per_sample);
        end

        %% ML UTILITIES ===================================================
        function [w,bl] = ridge_cv(~,Xr,yr,Xv,yv)
            % RIDGE_CV  Ridge regression with the penalty chosen on a
            % validation set.
            %
            %   Xr, yr  training design matrix and targets
            %   Xv, yv  validation design matrix and targets
            %   w       weights for the lambda with the lowest validation MSE
            %   bl      that lambda (0.01 if no candidate improved on Inf,
            %           e.g. when the validation error is NaN)
            %
            % Candidates: 1e-3 .. 1e3 in decades. The first coefficient is
            % left unpenalized (presumably an intercept column - confirm
            % against the feature extractor).
            lams = [0.001,0.01,0.1,1,10,100,1000];
            pp = size(Xr,2);
            R = eye(pp);
            if pp > 0, R(1,1) = 0; end

            % Gram matrix and right-hand side do not depend on lambda.
            G = Xr'*Xr; rhs = Xr'*yr;

            be = Inf; bl = 0.01; w = [];
            for lam = lams
                wt = (G+lam*R)\rhs;
                e = mean((yv-Xv*wt).^2);
                if e < be, be = e; bl = lam; w = wt; end
            end
            % No candidate selected: fall back to the default lambda.
            if isempty(w), w = (G+bl*R)\rhs; end
        end

        function [W,B] = train_mlp(~,X,y,Xv,yv,layers,maxep,lr)
            % TRAIN_MLP  Train a ReLU multilayer perceptron for scalar
            % regression with mini-batch Adam and early stopping.
            %
            %   X, y     training inputs (n x d) and targets (n values)
            %   Xv, yv   validation inputs and targets
            %   layers   layer widths [d, h1, ..., 1]
            %   maxep    maximum number of epochs
            %   lr       Adam learning rate
            %   W, B     cell arrays of weights / row-vector biases of the
            %            epoch with the lowest validation MSE
            %
            % Details visible below: He-style init randn*sqrt(2/fan_in),
            % batch size min(256, n), Adam with b1=0.9, b2=0.999, eps=1e-8,
            % patience of 25 epochs without validation improvement.
            % Uses the global RNG (randn, randperm).
            %
            % Adam's bias correction is driven by the number of parameter
            % updates performed so far, not by the epoch index, so it stays
            % correct when an epoch contains several mini-batches.
            % With an empty validation set early stopping is disabled and
            % the weights after the last epoch are returned.

            % Force column targets so the residual below cannot broadcast
            % an n x 1 prediction against a 1 x n target.
            y = y(:); yv = yv(:);

            nl=length(layers)-1; W=cell(nl,1); B=cell(nl,1);
            for l=1:nl
                fi=layers(l);
                W{l}=randn(fi,layers(l+1))*sqrt(2/fi);
                B{l}=zeros(1,layers(l+1));
            end
            mW=cellfun(@(w)zeros(size(w)),W,'UniformOutput',false); vW=mW;
            mB=cellfun(@(b)zeros(size(b)),B,'UniformOutput',false); vB=mB;
            b1=0.9; b2=0.999; ea=1e-8; bv=Inf; pa=25; wa=0; Wb=W; Bb=B;
            nn=size(X,1); bs=min(256,nn);
            has_val = ~isempty(Xv);
            step = 0;   % number of Adam updates performed
            for ep=1:maxep
                pm=randperm(nn);
                for s=1:bs:nn
                    bi=pm(s:min(s+bs-1,nn)); Xb=X(bi,:); yb=y(bi);
                    % Forward pass, keeping the activations of every layer.
                    A=cell(nl+1,1); A{1}=Xb;
                    for l=1:nl
                        Zl=A{l}*W{l}+B{l};
                        if l<nl, A{l+1}=max(0,Zl); else, A{l+1}=Zl; end
                    end
                    % Gradient of 0.5*mean squared error w.r.t. the output.
                    dZ=(A{nl+1}-yb)/size(Xb,1);
                    step = step+1;
                    c1 = 1-b1^step; c2 = 1-b2^step;
                    for l=nl:-1:1
                        gW=A{l}'*dZ; gB=sum(dZ,1);
                        % Back-propagate with the pre-update weights.
                        if l>1, dZ=(dZ*W{l}').*(A{l}>0); end
                        mW{l}=b1*mW{l}+(1-b1)*gW; vW{l}=b2*vW{l}+(1-b2)*gW.^2;
                        mB{l}=b1*mB{l}+(1-b1)*gB; vB{l}=b2*vB{l}+(1-b2)*gB.^2;
                        W{l}=W{l}-lr*(mW{l}/c1)./(sqrt(vW{l}/c2)+ea);
                        B{l}=B{l}-lr*(mB{l}/c1)./(sqrt(vB{l}/c2)+ea);
                    end
                end
                if ~has_val
                    % Nothing to validate against: keep the latest weights.
                    Wb=W; Bb=B;
                    continue;
                end
                yp=Xv; for l=1:nl, yp=yp*W{l}+B{l}; if l<nl, yp=max(0,yp); end; end
                vm=mean((yv-yp).^2);
                if vm<bv, bv=vm; wa=0; Wb=W; Bb=B; else, wa=wa+1; if wa>=pa, break; end; end
            end
            W=Wb; B=Bb;
        end

        function yp = predict_mlp(~,X,W,B)
            % PREDICT_MLP  Forward pass of the MLP: an affine map per layer,
            % ReLU after every layer except the last.
            %
            %   X     inputs, one sample per row
            %   W, B  cell arrays of weight matrices and row-vector biases
            %   yp    network output; X itself when W is empty
            n_layers = length(W);
            yp = X;
            % Hidden layers: affine map followed by ReLU.
            for l = 1:n_layers-1
                yp = max(0, yp*W{l} + B{l});
            end
            % Output layer: affine map only.
            if n_layers >= 1
                yp = yp*W{n_layers} + B{n_layers};
            end
        end

        function v = r2(~,yt,yp)
            % R2  Coefficient of determination of predictions yp against
            % targets yt. A 1e-20 term in the denominator avoids division
            % by zero when yt is constant.
            ss_res = sum((yt-yp).^2);
            ss_tot = sum((yt-mean(yt)).^2);
            v = 1 - ss_res/(ss_tot+1e-20);
        end
    end
end
