%% ========================================================================
%% product_group_experiment.m
%% Demonstrate the compositional advantage of star_G over product groups
%%
%% DESIGN:
%% G1 = Z_n1 (angular measurement in xy-plane: shifts under z-rotation)
%% G2 = Z_n2 (axial measurement along z: shifts under z-translation)
%% G = G1 x G2 (commuting product group, order n1*n2)
%%
%% The two group actions COMMUTE because:
%% R_z changes (x,y) but not z
%% T_z changes z but not (x,y)
%%
%% Features at (g1,g2) include:
%% - angular projections (shift in g1, constant in g2)
%% - axial periodic functions (constant in g1, shift in g2)
%% - coupled products (shift in both)
%%
%% The target depends on the 2D Fourier content of the features. The
%% COUPLED terms (freq1>0 AND freq2>0) dominate, so only a model built on
%% the product group can capture the full signal.
%%
%% LH & Claude 2026
%% ========================================================================
classdef product_group_experiment < handle
    % PRODUCT_GROUP_EXPERIMENT  star_G features on the product group
    % G = Z_n1 x Z_n2 versus factor-only, cyclic Z_(n1*n2) and
    % non-equivariant baselines.
    %
    % Group elements are indexed linearly with g1 varying fastest,
    %   g_lin = (g2-1)*n1 + g1,
    % by build_features, fold_to_factor and verify_product_structure.
    %
    % Workflow: construct -> load_data -> compute_features ->
    % run_comparison / discover_factorization. For synthetic data the
    % target is written to column 8 of properties_mat.

    properties
        n1; n2; n_molecules                     % factor orders; number of molecules held
        coords; atomic_numbers; properties_mat  % per-molecule cell arrays; n_mol x 15 property matrix
        G_prod; G1; G2                          % StarGAlgebra objects for G, Z_n1 and Z_n2
        X_tensor; n_feat_per_rot                % n_mol x n_feat x (n1*n2) features; n_feat
        is_real_data; results                   % true when loaded from .xyz; last run_comparison output
        L_axial % characteristic length for axial embedding
    end

    methods
        function obj = product_group_experiment(n1, n2, varargin)
            % Construct the experiment for G = Z_n1 x Z_n2.
            %   n1, n2         orders of the two cyclic factors (positive integers)
            %   'n_molecules'  accepted for backward compatibility but not used;
            %                  the molecule count is set by load_data.
            isPosInt = @(x) isnumeric(x) && isscalar(x) && x >= 1 && x == floor(x);
            p = inputParser;
            addRequired(p, 'n1', isPosInt);
            addRequired(p, 'n2', isPosInt);
            addParameter(p, 'n_molecules', 1000);
            parse(p, n1, n2, varargin{:});
            obj.n1 = n1; obj.n2 = n2;
            obj.n_molecules = 0;
            obj.L_axial = 6;   % z-range for periodic embedding
            obj.is_real_data = false;
            obj.G1 = StarGAlgebra('cyclic', n1);
            obj.G2 = StarGAlgebra('cyclic', n2);
            obj.G_prod = StarGAlgebra('product', {obj.G1, obj.G2});
            fprintf('Product Group: G = Z_%d x Z_%d (order %d)\n', n1, n2, obj.G_prod.n);
            fprintf(' G1 = Z_%d (xy-angular, shifts under z-rotation)\n', n1);
            fprintf(' G2 = Z_%d (z-axial, shifts under z-translation)\n', n2);
        end

        %% DATA ============================================================
        function obj = load_data(obj, data_dir, n_max)
            % Load up to n_max molecules from *.xyz files in data_dir.
            % Falls back to generate_synthetic(min(n_max,1000)) when data_dir
            % is missing, is not a folder, or contains no .xyz files.
            if nargin < 3, n_max = 1000; end
            if nargin >= 2 && ~isempty(data_dir) && isfolder(data_dir)
                xyz_files = dir(fullfile(data_dir, '*.xyz'));
                if ~isempty(xyz_files)
                    obj = obj.load_xyz_files(xyz_files, n_max);
                    obj.is_real_data = true;
                    return;
                end
            end
            obj = obj.generate_synthetic(min(n_max, 1000));
        end

        function obj = load_xyz_files(obj, xyz_files, n_max)
            % Parse up to n_max files (a dir() struct array) into coords,
            % atomic_numbers and properties_mat (15 property columns).
            % Unreadable, malformed or truncated files and files with NaN
            % coordinates are skipped; the skipped count is reported.
            n_files = min(length(xyz_files), n_max);
            fprintf('Loading %d .xyz files...\n', n_files); t0 = tic;
            obj.coords = cell(n_files, 1); obj.atomic_numbers = cell(n_files, 1);
            obj.properties_mat = zeros(n_files, 15);
            em = containers.Map({'H','C','N','O','F','S','Cl','Br','I','P','Si','B'}, ...
                                {1,6,7,8,9,16,17,35,53,15,14,5});
            valid = 0;
            for f = 1:n_files
                fp = fullfile(xyz_files(f).folder, xyz_files(f).name);
                fid = fopen(fp, 'r');
                if fid >= 0
                    try
                        [pos, Z, pr, ok] = product_group_experiment.parse_xyz(fid, em);
                    catch
                        ok = false;   % best-effort loader: treat parse errors as invalid
                    end
                    fclose(fid);      % closed exactly once, whatever happened above
                    if ok
                        valid = valid + 1;
                        obj.coords{valid} = pos; obj.atomic_numbers{valid} = Z;
                        obj.properties_mat(valid, 1:length(pr)) = pr;
                    end
                end
                if mod(f, 2000) == 0, fprintf(' %d/%d (%.1fs)\n', f, n_files, toc(t0)); end
            end
            obj.coords = obj.coords(1:valid); obj.atomic_numbers = obj.atomic_numbers(1:valid);
            obj.properties_mat = obj.properties_mat(1:valid, :); obj.n_molecules = valid;
            fprintf('Loaded %d molecules (%.1fs)\n', valid, toc(t0));
            if valid < n_files
                fprintf('Skipped %d file(s) that could not be parsed\n', n_files - valid);
            end
        end

        function obj = generate_synthetic(obj, n_mol)
            % Generate n_mol random molecules (4-12 atoms from {H,C,N,O,F},
            % Gaussian positions, rng(42)) and a synthetic target in column 8
            % of properties_mat built from the 2D Fourier power of coupled
            % feature rows 7 and 8.
            fprintf('Generating %d synthetic molecules...\n', n_mol);
            n1 = obj.n1; n2 = obj.n2; ng = n1*n2;
            % 2D freq indices: (f1,f2) -> (f2-1)*n1 + f1 (1-indexed).
            % NOTE(review): this assumes F_prod orders frequencies with f1
            % fastest (matching the g1-fastest feature layout). If StarGAlgebra
            % builds F as kron(F1,F2) (f2 fastest), these indices address
            % transposed frequencies; confirm.
            i22 = (2-1)*n1+2;   % coupled (1,1) in 0-indexed
            i23 = (3-1)*n1+2;   % coupled (1,2)
            i32 = (2-1)*n1+3;   % coupled (2,1)
            i20 = 2;            % axis-1 only
            i02 = (2-1)*n1+1;   % axis-2 only
            need = max([i22, i23, i32, i20, i02]);
            if need > ng
                error('product_group_experiment:groupTooSmall', ...
                    'Synthetic target needs n1*n2 >= %d (got Z_%d x Z_%d).', need, n1, n2);
            end
            rng(42, 'twister'); atom_types = [1,6,7,8,9];
            obj.coords = cell(n_mol, 1); obj.atomic_numbers = cell(n_mol, 1);
            obj.properties_mat = zeros(n_mol, 15);
            obj.is_real_data = false;
            ang1 = (0:n1-1)*2*pi/n1; ang2 = (0:n2-1)*2*pi/n2;
            F_prod = obj.G_prod.F;   % product group Fourier matrix
            for i = 1:n_mol
                na = randi([4,12]); pos = randn(na,3)*1.5; pos = pos - mean(pos,1);
                Z = atom_types(randi(5,na,1))';
                obj.coords{i} = pos; obj.atomic_numbers{i} = Z;
                F_all = obj.build_features(pos, Z, ang1, ang2, 36);
                % Row 7 = w'*(p1.*az), row 8 = Z'*(p1.*az) (first two coupled rows)
                pw7 = abs(F_all(7,:) * F_prod).^2;
                pw8 = abs(F_all(8,:) * F_prod).^2;
                % NOTE(review): row 7 is a sum of first harmonics in g1 times
                % first harmonics in g2, so its 2D spectrum sits at (+-1,+-1).
                % Unless an index aliases onto +-1 (e.g. f=2 when n=3), the
                % (1,0), (0,1), (1,2) and (2,1) powers below are zero up to
                % round-off and only the (1,1) terms contribute. Confirm intent.
                % Target: LINEAR in per-row 2D Fourier power
                y_c = 5.0*pw7(i22) + 3.0*pw7(i23) + 3.0*pw7(i32) + 2.0*pw8(i22);
                y_1 = 1.0*pw7(i20);
                y_2 = 1.0*pw7(i02);
                obj.properties_mat(i,8) = y_c + y_1 + y_2;
            end
            obj.n_molecules = n_mol;
            y = obj.properties_mat(1:n_mol, 8);
            fprintf('Target: linear in per-row 2D Fourier power of coupled rows\n');
            fprintf(' Coupled coeff=13 vs single-axis=2 (6.5x ratio)\n');
            fprintf(' y range: [%.2f, %.2f], std=%.2f\n', min(y), max(y), std(y));
        end

        %% FEATURE COMPUTATION ============================================
        function obj = compute_features(obj, varargin)
            % Build obj.X_tensor (n_mol x n_feat x n1*n2) by applying
            % build_features to every centred molecule, then run
            % verify_product_structure. 'n_feat' (default 36) = rows per element.
            p = inputParser; addParameter(p, 'n_feat', 36); parse(p, varargin{:});
            nft = p.Results.n_feat; n_mol = obj.n_molecules;
            if n_mol < 1
                error('product_group_experiment:noData', ...
                    'No molecules loaded; call load_data first.');
            end
            n1 = obj.n1; n2 = obj.n2; ng = n1*n2;
            ang1 = (0:n1-1)*2*pi/n1; ang2 = (0:n2-1)*2*pi/n2;
            % Molecule 1 fixes the row count; its features are reused below.
            pos1 = obj.coords{1} - mean(obj.coords{1}, 1);
            feat1 = obj.build_features(pos1, obj.atomic_numbers{1}, ang1, ang2, nft);
            nf = size(feat1, 1); obj.n_feat_per_rot = nf;
            X = zeros(n_mol, nf, ng); t0 = tic;
            fprintf('Computing product features: %d mol x %d feat x %d group...\n', n_mol, nf, ng);
            X(1,:,:) = reshape(feat1, 1, nf, ng);
            for mol = 2:n_mol
                pos = obj.coords{mol} - mean(obj.coords{mol}, 1);
                X(mol,:,:) = reshape(obj.build_features(pos, obj.atomic_numbers{mol}, ...
                    ang1, ang2, nft), 1, nf, ng);
                if mod(mol, 500) == 0, fprintf(' %d/%d (%.1fs)\n', mol, n_mol, toc(t0)); end
            end
            obj.X_tensor = X;
            fprintf('Feature tensor: %d x %d x %d (%.1fs)\n', ...
                size(X,1), size(X,2), size(X,3), toc(t0));
            obj.verify_product_structure();
        end

        function F = build_features(obj, coords, Z, ang1, ang2, nft)
            % Return an nft x (n1*n2) matrix; column (g2-1)*n1+g1 holds the
            % features for group element (g1,g2). Row order:
            %   1-3   angular  (depend on g1 only, replicated over g2)
            %   4-6   axial    (depend on g2 only, replicated over g1)
            %   7-10  coupled  (depend on g1 and g2)
            %   11-14 pairwise-distance mean/std/min/max (invariant)
            %   15-17 sum(Z.^2)/100, mean(Z), atom count (invariant)
            %   18-20 higher-order angular (g1 only)
            % Rows are zero-padded or truncated to nft.
            n1 = length(ang1); n2 = length(ang2); ng = n1*n2;
            na = size(coords,1); Zn = Z(:); w = Zn/(sum(Zn)+1e-10);
            L = obj.L_axial;
            x = coords(:,1); y = coords(:,2);
            phase = 2*pi*coords(:,3)/L;
            ca = cos(ang1(:)'); sa = sin(ang1(:)');
            P1 = x*ca + y*sa;            % na x n1: projection on rotated x-axis per g1
            P2 = -x*sa + y*ca;           % na x n1: projection on rotated y-axis per g1
            AZc = cos(phase - ang2(:)'); % na x n2: periodic z-embedding per g2
            AZs = sin(phase - ang2(:)');
            % Lay out per-factor values on the g1-fastest product grid.
            rep_g1 = @(v) repmat(v(:)', 1, n2);
            rep_g2 = @(v) reshape(repmat(v(:)', n1, 1), 1, ng);
            on_grid = @(M) reshape(M, 1, ng);   % n1 x n2 -> 1 x ng
            ang_rows = [rep_g1(w'*P1); rep_g1(Zn'*P1); rep_g1(w'*(P1.^2))];
            ax_rows  = [rep_g2(w'*AZc); rep_g2(Zn'*AZc); rep_g2(w'*AZs)];
            cpl_rows = [on_grid((w.*P1)'*AZc); ...
                        on_grid((Zn.*P1)'*AZc); ...
                        on_grid((w.*P1.^2)'*AZc); ...
                        on_grid((w.*P1)'*(AZc.^2))];
            if na >= 2
                D = pdist(coords);
                inv_rows = repmat([mean(D); std(D); min(D); max(D)], 1, ng);
            else
                inv_rows = zeros(4, ng);
            end
            comp_rows = repmat([sum(Zn.^2)/100; mean(Zn); na], 1, ng);
            ho_rows = [rep_g1(w'*(P1.^3)); rep_g1(w'*(P1.*P2)); rep_g1(Zn'*(P1.^2))];
            F = [ang_rows; ax_rows; cpl_rows; inv_rows; comp_rows; ho_rows];
            nr = size(F,1);
            if nr < nft, F = [F; zeros(nft-nr, ng)];
            elseif nr > nft, F = F(1:nft,:); end
        end

        function verify_product_structure(obj)
            % Check numerically that build_features is equivariant:
            %   rotation about z by 2*pi/n1   -> circular shift by +1 along g1
            %   translation along z by L/n2   -> circular shift by +1 along g2
            % Uses molecule min(3,n_molecules); prints relative errors with
            % PASS (< 1e-10) or WARN.
            if obj.n_molecules < 1, return; end
            mi = min(3, obj.n_molecules);
            pos = obj.coords{mi} - mean(obj.coords{mi}, 1);
            Z = obj.atomic_numbers{mi};
            ang1 = (0:obj.n1-1)*2*pi/obj.n1; ang2 = (0:obj.n2-1)*2*pi/obj.n2;
            nft = obj.n_feat_per_rot;
            F0_3d = reshape(obj.build_features(pos, Z, ang1, ang2, nft), [], obj.n1, obj.n2);
            nrm0 = norm(F0_3d(:)) + 1e-20;
            th1 = 2*pi/obj.n1;
            R1 = [cos(th1), -sin(th1), 0; sin(th1), cos(th1), 0; 0, 0, 1];
            Fr1_3d = reshape(obj.build_features((R1*pos')', Z, ang1, ang2, nft), [], obj.n1, obj.n2);
            err1 = norm(Fr1_3d(:) - reshape(circshift(F0_3d, 1, 2), [], 1)) / nrm0;
            pos_tz = pos; pos_tz(:,3) = pos_tz(:,3) + obj.L_axial/obj.n2;
            Ftz_3d = reshape(obj.build_features(pos_tz, Z, ang1, ang2, nft), [], obj.n1, obj.n2);
            err2 = norm(Ftz_3d(:) - reshape(circshift(F0_3d, 1, 3), [], 1)) / nrm0;
            fprintf('Product structure check:\n');
            fprintf(' G1 shift (z-rotation): %.2e %s\n', err1, tif(err1<1e-10,'PASS','WARN'));
            fprintf(' G2 shift (z-translation): %.2e %s\n', err2, tif(err2<1e-10,'PASS','WARN'));
        end

        %% FOLD DATA ======================================================
        function X_sub = fold_to_factor(obj, X, factor)
            % Average the product-group axis of X (ns x nf x n1*n2) over one
            % factor: factor==1 keeps G1 (ns x nf x n1), anything else keeps
            % G2 (ns x nf x n2). reshape, not squeeze, so ns==1 or nf==1 keep
            % their dimensions.
            [ns, nf, ~] = size(X);
            X_4d = reshape(X, ns, nf, obj.n1, obj.n2);
            if factor == 1
                X_sub = reshape(mean(X_4d, 4), ns, nf, obj.n1);
            else
                X_sub = reshape(mean(X_4d, 3), ns, nf, obj.n2);
            end
        end

        %% COMPARISON =====================================================
        function results = run_comparison(obj, target_col, varargin)
            % Train and evaluate 8 models on n_seeds random splits.
            %   target_col  column of properties_mat used as the target
            %   'n_seeds'   number of random splits (default 3)
            %   'split'     [train val test] fractions (default [0.7 0.15 0.15]);
            %               the test set gets whatever remains after rounding
            % Returns (and stores in obj.results) a struct with method_names,
            % R2 and RMSE (n_methods x n_seeds) and params.
            p = inputParser; addParameter(p, 'n_seeds', 3); addParameter(p, 'split', [0.7,0.15,0.15]);
            parse(p, varargin{:});
            n_seeds = p.Results.n_seeds; sp = p.Results.split;
            if isempty(obj.X_tensor)
                error('product_group_experiment:noFeatures', ...
                    'X_tensor is empty; call compute_features first.');
            end
            y = obj.properties_mat(:, target_col); n = obj.n_molecules;
            ntr = round(sp(1)*n); nva = round(sp(2)*n); nte = n - ntr - nva;
            if ntr < 1 || nva < 1 || nte < 1
                error('product_group_experiment:badSplit', ...
                    'Split [%g %g %g] of %d molecules leaves an empty train/val/test set.', ...
                    sp(1), sp(2), 1-sp(1)-sp(2), n);
            end
            mn = {'G1xG2_Ridge','G1xG2_MLP','G1_only_Ridge','G2_only_Ridge', ...
                  'Z_n_Ridge','Standard_MLP','Invariant_MLP','Augmented_MLP'};
            nm = length(mn);
            R2_all = zeros(nm, n_seeds); RMSE_all = zeros(nm, n_seeds); params = zeros(nm, 1);
            rmse = @(a, b) sqrt(mean((a - b).^2));
            n_params = @(W, B) sum(cellfun(@numel, W)) + sum(cellfun(@numel, B));
            X_g1 = obj.fold_to_factor(obj.X_tensor, 1);
            X_g2 = obj.fold_to_factor(obj.X_tensor, 2);
            G_cyc = StarGAlgebra('cyclic', obj.n1*obj.n2);
            fprintf('\n============================================================\n');
            fprintf(' Product Group: Z_%d x Z_%d, %d mol, %d seeds\n', obj.n1, obj.n2, n, n_seeds);
            fprintf('============================================================\n');
            for seed = 1:n_seeds
                fprintf('\n, Seed %d/%d , \n', seed, n_seeds);
                rng(seed*111); idx = randperm(n);
                tri = idx(1:ntr); vai = idx(ntr+1:ntr+nva); tei = idx(ntr+nva+1:end);
                ytr = y(tri); yva = y(vai); yte = y(tei);
                Xtr = obj.X_tensor(tri,:,:); Xva = obj.X_tensor(vai,:,:); Xte = obj.X_tensor(tei,:,:);
                % 1. G1xG2 + Ridge
                t0 = tic; fprintf(' [1] G1xG2 Ridge...');
                [ftr, fva, fte] = obj.star_features(Xtr, Xva, Xte, obj.G_prod);
                w = obj.ridge_cv(ftr, ytr, fva, yva); yp = fte*w;
                R2_all(1,seed) = obj.r2(yte, yp); RMSE_all(1,seed) = rmse(yte, yp);
                params(1) = numel(w);
                fprintf(' R2=%.4f (%.1fs)\n', R2_all(1,seed), toc(t0));
                % 2. G1xG2 + MLP (on the same star_G features)
                t0 = tic; fprintf(' [2] G1xG2 MLP...');
                [wm, bm] = obj.train_mlp(ftr, ytr, fva, yva, [size(ftr,2),64,32,1], 300, 0.003);
                yp2 = obj.predict_mlp(fte, wm, bm);
                R2_all(2,seed) = obj.r2(yte, yp2); RMSE_all(2,seed) = rmse(yte, yp2);
                params(2) = n_params(wm, bm);
                fprintf(' R2=%.4f (%.1fs)\n', R2_all(2,seed), toc(t0));
                % 3. G1 only
                t0 = tic; fprintf(' [3] G1 only (Z_%d)...', obj.n1);
                [f1, fv1, ft1] = obj.star_features(X_g1(tri,:,:), X_g1(vai,:,:), X_g1(tei,:,:), obj.G1);
                w1 = obj.ridge_cv(f1, ytr, fv1, yva); yp1 = ft1*w1;
                R2_all(3,seed) = obj.r2(yte, yp1); RMSE_all(3,seed) = rmse(yte, yp1);
                params(3) = numel(w1);
                fprintf(' R2=%.4f (%.1fs)\n', R2_all(3,seed), toc(t0));
                % 4. G2 only
                t0 = tic; fprintf(' [4] G2 only (Z_%d)...', obj.n2);
                [f2, fv2, ft2] = obj.star_features(X_g2(tri,:,:), X_g2(vai,:,:), X_g2(tei,:,:), obj.G2);
                w2 = obj.ridge_cv(f2, ytr, fv2, yva); yp2f = ft2*w2;
                R2_all(4,seed) = obj.r2(yte, yp2f); RMSE_all(4,seed) = rmse(yte, yp2f);
                params(4) = numel(w2);
                fprintf(' R2=%.4f (%.1fs)\n', R2_all(4,seed), toc(t0));
                % 5. Z_{n1*n2} cyclic (wrong structure)
                t0 = tic; fprintf(' [5] Z_%d cyclic...', obj.n1*obj.n2);
                [fc, fvc, ftc] = obj.star_features(Xtr, Xva, Xte, G_cyc);
                wc = obj.ridge_cv(fc, ytr, fvc, yva); ypc = ftc*wc;
                R2_all(5,seed) = obj.r2(yte, ypc); RMSE_all(5,seed) = rmse(yte, ypc);
                params(5) = numel(wc);
                fprintf(' R2=%.4f (%.1fs)\n', R2_all(5,seed), toc(t0));
                % 6. Standard MLP on the identity-element features only
                t0 = tic; fprintf(' [6] Standard MLP...');
                Xs = Xtr(:,:,1); nfs = size(Xs,2);
                mu6 = mean(Xs); s6 = std(Xs) + 1e-8;
                [w6, b6] = obj.train_mlp((Xs-mu6)./s6, ytr, (Xva(:,:,1)-mu6)./s6, yva, [nfs,64,32,1], 300, 0.003);
                yp6 = obj.predict_mlp((Xte(:,:,1)-mu6)./s6, w6, b6);
                R2_all(6,seed) = obj.r2(yte, yp6); RMSE_all(6,seed) = rmse(yte, yp6);
                params(6) = n_params(w6, b6);
                fprintf(' R2=%.4f (%.1fs)\n', R2_all(6,seed), toc(t0));
                % 7. Invariant MLP on mean/std over the group axis
                t0 = tic; fprintf(' [7] Invariant MLP...');
                fi_tr = [mean(Xtr,3), std(Xtr,0,3)];
                fi_va = [mean(Xva,3), std(Xva,0,3)];
                fi_te = [mean(Xte,3), std(Xte,0,3)];
                mu7 = mean(fi_tr); s7 = std(fi_tr) + 1e-8;
                [w7, b7] = obj.train_mlp((fi_tr-mu7)./s7, ytr, (fi_va-mu7)./s7, yva, [size(fi_tr,2),64,32,1], 300, 0.003);
                yp7 = obj.predict_mlp((fi_te-mu7)./s7, w7, b7);
                R2_all(7,seed) = obj.r2(yte, yp7); RMSE_all(7,seed) = rmse(yte, yp7);
                params(7) = n_params(w7, b7);
                fprintf(' R2=%.4f (%.1fs)\n', R2_all(7,seed), toc(t0));
                % 8. Augmented MLP: every group element is a training sample
                t0 = tic; fprintf(' [8] Augmented MLP...');
                nfa = obj.n_feat_per_rot; ng = obj.n1*obj.n2;
                Xa = reshape(permute(Xtr, [1,3,2]), [], nfa); ya = repmat(ytr, ng, 1);
                Xav = reshape(permute(Xva, [1,3,2]), [], nfa); yav = repmat(yva, ng, 1);
                mu8 = mean(Xa); s8 = std(Xa) + 1e-8;
                [w8, b8] = obj.train_mlp((Xa-mu8)./s8, ya, (Xav-mu8)./s8, yav, [nfa,64,32,1], 200, 0.003);
                yp8 = obj.predict_mlp((Xte(:,:,1)-mu8)./s8, w8, b8);
                R2_all(8,seed) = obj.r2(yte, yp8); RMSE_all(8,seed) = rmse(yte, yp8);
                params(8) = n_params(w8, b8);
                fprintf(' R2=%.4f (%.1fs)\n', R2_all(8,seed), toc(t0));
            end
            fprintf('\n============================================================\n');
            fprintf(' Product Group Results: Z_%d x Z_%d (%d seeds)\n', obj.n1, obj.n2, n_seeds);
            fprintf('============================================================\n');
            fprintf('%-22s %12s %12s %8s\n', 'Method', 'Test R2', 'RMSE', 'Params');
            fprintf('%s\n', repmat('-', 1, 56));
            for m = 1:nm
                fprintf('%-22s %5.3f+/-%5.3f %6.4f+/-%6.4f %8d\n', ...
                    mn{m}, mean(R2_all(m,:)), std(R2_all(m,:)), mean(RMSE_all(m,:)), std(RMSE_all(m,:)), params(m));
            end
            results.method_names = mn; results.R2 = R2_all; results.RMSE = RMSE_all; results.params = params;
            obj.results = results;
        end

        %% FACTORIZATION DISCOVERY ========================================
        function report = discover_factorization(obj, target_col)
            % Fit star_G Ridge for every factorisation ng = a*b (2 <= a <= b)
            % as Z_a x Z_b, plus cyclic Z_ng, on one fixed 50/20/30 split
            % (rng(0)). Prints test R2 for each and the best. A failed fit
            % reports R2=NaN and prints its error message.
            y = obj.properties_mat(:, target_col); ng = obj.n1*obj.n2;
            fprintf('\n============================================================\n');
            fprintf(' Factorization Discovery: n=%d\n', ng);
            fprintf('============================================================\n');
            factors = zeros(0, 2);
            for a = 2:floor(sqrt(ng))   % a <= b  <=>  a^2 <= ng
                if mod(ng, a) == 0, factors(end+1,:) = [a, ng/a]; end %#ok<AGROW>
            end
            n = obj.n_molecules; rng(0); idx = randperm(n);
            ntr = round(0.5*n); nva = round(0.2*n);
            tri = idx(1:ntr); vai = idx(ntr+1:ntr+nva); tei = idx(ntr+nva+1:end);
            ytr = y(tri); yva = y(vai); yte = y(tei);
            Xtr = obj.X_tensor(tri,:,:); Xva = obj.X_tensor(vai,:,:); Xte = obj.X_tensor(tei,:,:);
            report = struct('name', {}, 'r2', {});
            nfac = size(factors, 1);
            for f = 1:nfac+1
                if f <= nfac
                    a = factors(f,1); b = factors(f,2);
                    name = sprintf('Z_%d x Z_%d', a, b);
                    make_group = @() StarGAlgebra('product', {StarGAlgebra('cyclic',a), StarGAlgebra('cyclic',b)});
                else
                    name = sprintf('Z_%d (cyclic)', ng);
                    make_group = @() StarGAlgebra('cyclic', ng);
                end
                err_msg = '';
                try
                    [ftr, fva, fte] = obj.star_features(Xtr, Xva, Xte, make_group());
                    w = obj.ridge_cv(ftr, ytr, fva, yva);
                    r2v = obj.r2(yte, fte*w);
                catch err
                    r2v = NaN; err_msg = err.message;
                end
                report(end+1) = struct('name', name, 'r2', r2v); %#ok<AGROW>
                fprintf(' %-15s R2=%+.4f\n', name, r2v);
                if ~isempty(err_msg), fprintf('   (failed: %s)\n', err_msg); end
            end
            [~, bi] = max([report.r2]);
            fprintf('\n >>> Best: %s (R2=%.4f)\n', report(bi).name, report(bi).r2);
        end

        %% ML UTILITIES ===================================================
        function [w, bl] = ridge_cv(~, Xr, yr, Xv, yv)
            % Ridge regression; lambda chosen from a fixed grid by validation
            % MSE (first minimum wins). Column 1 is left unpenalised.
            % NOTE(review): this presumes column 1 of Xr is an intercept/bias
            % column produced by extractStarGFeatures; confirm.
            lams = [1e-3, 0.01, 0.1, 1, 10, 100, 1e3]; be = Inf; bl = 0.01; pp = size(Xr,2);
            R = eye(pp); R(1,1) = 0;
            for lam = lams
                wt = (Xr'*Xr + lam*R) \ (Xr'*yr);
                e = mean((yv - Xv*wt).^2);
                if e < be, be = e; bl = lam; end
            end
            w = (Xr'*Xr + bl*R) \ (Xr'*yr);
        end

        function [W, B] = train_mlp(~, X, y, Xv, yv, layers, maxep, lr)
            % Train a fully connected ReLU network with a linear output on
            % MSE using Adam, mini-batches of up to 256 rows and early stopping
            % on validation MSE (patience 20 epochs). Returns the weights with
            % the best validation MSE.
            %   layers  [n_in, hidden..., n_out] layer widths
            %   maxep   maximum epochs;  lr  Adam step size
            y = y(:); yv = yv(:);   % column targets so (A - yb) never broadcasts
            nl = length(layers) - 1; W = cell(nl,1); B = cell(nl,1);
            for l = 1:nl
                fi = layers(l);
                W{l} = randn(fi, layers(l+1))*sqrt(2/fi);   % He initialisation
                B{l} = zeros(1, layers(l+1));
            end
            mW = cellfun(@(w) zeros(size(w)), W, 'Uni', 0); vW = mW;
            mB = cellfun(@(b) zeros(size(b)), B, 'Uni', 0); vB = mB;
            b1 = .9; b2 = .999; ea = 1e-8;
            bv = Inf; pa = 20; wa = 0; Wb = W; Bb = B;
            nn = size(X,1); bs = min(256, nn);
            t = 0;   % global Adam step counter; bias correction needs steps, not epochs
            for ep = 1:maxep
                pm = randperm(nn);
                for s = 1:bs:nn
                    bi = pm(s:min(s+bs-1, nn)); Xb = X(bi,:); yb = y(bi);
                    A = cell(nl+1, 1); A{1} = Xb;
                    for l = 1:nl
                        Zl = A{l}*W{l} + B{l};
                        if l < nl, A{l+1} = max(0, Zl); else, A{l+1} = Zl; end
                    end
                    dZ = (A{nl+1} - yb)/size(Xb,1);
                    t = t + 1;
                    c1 = 1 - b1^t; c2 = 1 - b2^t;
                    for l = nl:-1:1
                        gW = A{l}'*dZ; gB = sum(dZ, 1);
                        % back-propagate with the pre-update weights
                        if l > 1, dZ = (dZ*W{l}').*(A{l} > 0); end
                        mW{l} = b1*mW{l} + (1-b1)*gW; vW{l} = b2*vW{l} + (1-b2)*gW.^2;
                        mB{l} = b1*mB{l} + (1-b1)*gB; vB{l} = b2*vB{l} + (1-b2)*gB.^2;
                        W{l} = W{l} - lr*(mW{l}/c1)./(sqrt(vW{l}/c2) + ea);
                        B{l} = B{l} - lr*(mB{l}/c1)./(sqrt(vB{l}/c2) + ea);
                    end
                end
                yp = Xv;
                for l = 1:nl
                    yp = yp*W{l} + B{l};
                    if l < nl, yp = max(0, yp); end
                end
                vm = mean((yv - yp).^2);
                if vm < bv
                    bv = vm; wa = 0; Wb = W; Bb = B;
                else
                    wa = wa + 1;
                    if wa >= pa, break; end
                end
            end
            W = Wb; B = Bb;
        end

        function yp = predict_mlp(~, X, W, B)
            % Forward pass matching train_mlp: ReLU hidden layers, linear output.
            yp = X;
            for l = 1:length(W)
                yp = yp*W{l} + B{l};
                if l < length(W), yp = max(0, yp); end
            end
        end

        function v = r2(~, yt, yp)
            % Coefficient of determination; the 1e-20 guards a constant yt.
            v = 1 - sum((yt - yp).^2)/(sum((yt - mean(yt)).^2) + 1e-20);
        end
    end

    methods (Access = private)
        function [ftr, fva, fte] = star_features(obj, Xtr, Xva, Xte, G)
            % Extract star_G features for train/val/test; the second output of
            % the training call (presumably normalisation state) is reused for
            % val/test.
            [ftr, np] = extractStarGFeatures(Xtr, G, obj.n_feat_per_rot);
            fva = extractStarGFeatures(Xva, G, obj.n_feat_per_rot, np);
            fte = extractStarGFeatures(Xte, G, obj.n_feat_per_rot, np);
        end
    end

    methods (Static, Access = private)
        function [pos, Z, pr, ok] = parse_xyz(fid, em)
            % Read one .xyz record from an open file id. Line 1 is the atom
            % count; tokens 3..17 of line 2 become up to 15 properties; each
            % atom line is "symbol x y z" ('*^' is read as 'e'; unknown
            % symbols map to Z=6). ok is false for an invalid header, a
            % truncated atom block, or any NaN coordinate.
            pos = []; Z = []; pr = []; ok = false;
            na = str2double(fgetl(fid));   % fgetl gives -1 at EOF -> NaN
            if ~isfinite(na) || na < 1 || na ~= floor(na), return; end
            pl = fgetl(fid);
            if ~ischar(pl), return; end
            tk = strsplit(strtrim(strrep(pl, '*^', 'e')));
            pr = zeros(1, max(0, min(15, numel(tk)-2)));
            for t = 1:numel(pr)
                v = str2double(tk{t+2});
                if ~isnan(v), pr(t) = v; end
            end
            Z = zeros(na, 1); pos = zeros(na, 3);
            for a = 1:na
                ln = fgetl(fid);
                if ~ischar(ln), return; end
                pts = strsplit(strtrim(strrep(ln, char(9), ' ')));
                if numel(pts) < 4, return; end
                if isKey(em, pts{1}), Z(a) = em(pts{1}); else, Z(a) = 6; end
                pos(a,:) = str2double(strrep(pts(2:4), '*^', 'e'));
            end
            ok = ~any(isnan(pos(:)));
        end
    end
end
function s = tif(c, a, b)
% TIF  Inline conditional: return a when condition c holds, otherwise b.
s = b;
if c
    s = a;
end
end