% tensor-group-sym / experiments / diagnose_ridge.m
% diagnose_ridge.m
% Raw
%% diagnose_ridge.m - Quick verification
function diagnose_ridge()
%DIAGNOSE_RIDGE Quick check of intercept-exempt ridge regression on synthetic data.
%   DIAGNOSE_RIDGE() takes no inputs and returns nothing; every result is
%   printed to the command window. The function:
%     1. adds the sibling 'core' and 'experiments' folders to the MATLAB path;
%     2. generates n_mol random 3-D point clouds with random atomic numbers
%        and a synthetic scalar target y for each;
%     3. passes them through QM9_experiment.compute_rotated_features and
%        extractStarGFeatures to obtain a feature matrix;
%     4. fits ridge regression (column 1 unpenalised) on a random 70/30
%        split for several lambda values and prints the validation R^2.
%
%   Requires pdist/squareform (Statistics and Machine Learning Toolbox)
%   and the project-defined StarGAlgebra, QM9_experiment and
%   extractStarGFeatures.

    % Locate the repository root relative to this file and expose the code.
    thisDir = fileparts(mfilename('fullpath'));
    rootDir = fileparts(thisDir);
    addpath(fullfile(rootDir, 'core'));
    addpath(fullfile(rootDir, 'experiments'));

    fprintf('\n=== QUICK DIAGNOSTIC ===\n\n');

    % --- Configuration -------------------------------------------------
    n_rot = 12;                      % order of the cyclic group / number of angles
    G = StarGAlgebra('cyclic', n_rot);
    rng(42);                         % fixed seed so every run prints the same numbers
    n_mol = 500;                     % number of synthetic samples
    n_feat = 48;                     % feature count passed to both feature routines
    target_col = 8;                  % properties_mat column that receives y
                                     % (meaning of this column is defined by QM9_experiment)
    atom_types = [1,6,7,8,9];

    % Unit vectors in the xy-plane at n_rot equally spaced angles (3 x n_rot).
    angles = (0:n_rot-1)*2*pi/n_rot;
    e1 = [cos(angles); sin(angles); zeros(1,n_rot)];

    % --- Synthetic data ------------------------------------------------
    y = zeros(n_mol,1);
    coords_all = cell(n_mol,1);
    Z_all = cell(n_mol,1);
    for i = 1:n_mol
        na = randi([4,12]);                     % number of atoms in this sample
        pos = randn(na,3)*1.5;
        pos = pos - mean(pos,1);                % centre the coordinates
        % Indexing a row vector with a vector index yields a row, so the
        % transpose makes Z an na-by-1 column.
        Z = atom_types(randi(numel(atom_types),na,1))';
        coords_all{i} = pos;
        Z_all{i} = Z;

        D = squareform(pdist(pos));
        dd = D(D>0);                            % pairwise distances, diagonal zeros dropped
        zw = Z/(sum(Z)+1e-10);                  % atomic-number weights
        % Power spectrum of the weighted centroid projected on each direction.
        pw = abs(fft(zw'*pos*e1)).^2;
        y(i) = mean(dd) + 0.003*pw(2) + 0.002*pw(3) + sum(Z.^2)/500;
    end
    fprintf('y: mean=%.3f std=%.3f\n', mean(y), std(y));

    % --- Feature extraction --------------------------------------------
    % Named 'experiment' rather than 'exp' so the built-in exp() is not shadowed.
    experiment = QM9_experiment('dummy', n_rot);
    experiment.coords = coords_all;
    experiment.atomic_numbers = Z_all;
    experiment.n_molecules = n_mol;
    experiment.properties_mat = zeros(n_mol,15);
    experiment.properties_mat(:,target_col) = y;
    experiment = experiment.compute_rotated_features('n_feat', n_feat);
    X = experiment.X_tensor;

    [feat, ~] = extractStarGFeatures(X, G, n_feat);   % second output is not needed here
    fprintf('Features: %d x %d (incl intercept)\n', size(feat));
    has_intercept = all(feat(:,1)==1);
    fprintf('Has intercept: %s\n', mat2str(has_intercept));
    if ~has_intercept
        % The penalty below exempts column 1 on the assumption that it is
        % the intercept; make a violation visible instead of silently
        % leaving an ordinary feature unregularised.
        warning('diagnose_ridge:noIntercept', ...
            'Column 1 of the feature matrix is not all ones; the "intercept-exempt" penalty exempts a regular feature.');
    end

    % --- Train/validation split ----------------------------------------
    ntr = round(0.7*n_mol);
    idx = randperm(n_mol);
    tr = idx(1:ntr);
    va = idx(ntr+1:end);
    Xtr = feat(tr,:);  ytr = y(tr);
    Xva = feat(va,:);  yva = y(va);

    % Identity penalty with a zero for column 1 so that column is not shrunk.
    pp = size(Xtr,2);
    R = eye(pp);
    R(1,1) = 0;

    % --- Ridge sweep ---------------------------------------------------
    fprintf('\nRidge (intercept-exempt):\n');
    for lam = [0.001,0.01,0.1,1,10]
        beta = (Xtr'*Xtr + lam*R) \ (Xtr'*ytr);   % regularised normal equations
        yp = Xva*beta;
        % The 1e-20 guards against division by zero for a constant yva.
        r2 = 1 - sum((yva-yp).^2)/(sum((yva-mean(yva)).^2)+1e-20);
        fprintf('  lam=%6.3f: R2=%+.4f  mean_pred=%.3f mean_true=%.3f\n', lam, r2, mean(yp), mean(yva));
    end
    fprintf('\n=== END ===\n');
end