%% ========================================================================
%% symmetry_discovery.m
%% Discover the group structure that best explains symmetry in data
%%
%% PERFORMANCE: Tucker products use n-mode matrix multiplications
%% instead of O(n^6) nested loops.
%%
%% LH & Claude 2026
%% ========================================================================
classdef symmetry_discovery < handle
    % SYMMETRY_DISCOVERY  Score candidate finite groups against data and keep the best.
    %
    % Data layout used throughout this class: the first dimension of X is
    % iterated as samples, the second is treated as the feature count, and
    % the third is the axis the group transform is applied along.
    %
    % Candidate algebras are built with StarGAlgebra and supervised features
    % with extractStarGFeatures; both are defined outside this file. Of an
    % algebra object G, only G.n, G.isCyclic and G.F are read here.
    properties
        X                  % data array, indexed (sample, feature, group position)
        y                  % target column vector, or [] when no targets were given
        candidate_results  % struct array report stored by discover
        best_group         % algebra object of the best-scoring candidate
        best_group_name    % display name of the best-scoring candidate
        verbose            % true => progress is printed with fprintf
    end
    methods
        function obj = symmetry_discovery(X, y)
            % Store the data; y is optional and is flattened to a column.
            obj.X = X;
            if nargin >= 2 && ~isempty(y)
                obj.y = y(:);
            else
                obj.y = [];
            end
            obj.verbose = true;
        end

        function [best_G, report] = discover(obj, varargin)
            % DISCOVER  Build candidate groups, score each one, return the best.
            %
            % Name/value options:
            %   'max_order'         largest group order to try (0 => size(X,3))
            %   'test_dihedral'     include D_d candidates (default true)
            %   'test_symmetric'    accepted for compatibility; not used by this method
            %   'test_klein4'       include V_4 when size(X,3) is a multiple of 4
            %   'test_quaternion'   include Q_8 when size(X,3) is a multiple of 8
            %   'supervised_weight' weight of the prediction R2 in the combined score
            %
            % Returns the best algebra object and a struct array with one row
            % per candidate. Raises symmetry_discovery:noCandidates when no
            % candidate could be constructed.
            p = inputParser;
            addParameter(p,'max_order',0);
            addParameter(p,'test_dihedral',true);
            addParameter(p,'test_symmetric',false);
            addParameter(p,'test_klein4',true);
            addParameter(p,'test_quaternion',true);
            addParameter(p,'supervised_weight',0.4);
            parse(p, varargin{:});
            opts = p.Results;

            n_group = size(obj.X, 3);
            max_ord = opts.max_order;
            if max_ord == 0, max_ord = n_group; end
            sw = opts.supervised_weight;
            if isempty(obj.y), sw = 0; end   % no targets => purely unsupervised score

            % ---- candidate construction (failures are skipped, not fatal) ----
            candidates = {};
            divs = obj.divisors(n_group);
            divs = divs(divs >= 2 & divs <= max_ord);
            for d = divs(:)'
                candidates = obj.try_add_candidate(candidates, sprintf('Z_%d',d), 'cyclic', d);
            end
            if opts.test_dihedral
                for d = divs(:)'
                    if mod(n_group,2*d)==0 && 2*d<=max_ord && d>=3
                        candidates = obj.try_add_candidate(candidates, sprintf('D_%d',d), 'dihedral', d);
                    end
                end
            end
            if opts.test_klein4 && mod(n_group,4)==0
                candidates = obj.try_add_candidate(candidates, 'V_4', 'klein4');
            end
            if opts.test_quaternion && mod(n_group,8)==0
                candidates = obj.try_add_candidate(candidates, 'Q_8', 'quaternion');
            end

            n_cand = length(candidates);
            if n_cand == 0
                % Without this guard max([]) yields an empty index and the
                % lookup below fails with an unrelated indexing error.
                error('symmetry_discovery:noCandidates', ...
                    'No candidate group could be constructed (n_group = %d, max_order = %d).', ...
                    n_group, max_ord);
            end

            report = struct('name',{},'order',{},'svd_compress',{}, ...
                'fourier_sparsity',{},'prediction_r2',{},'combined_score',{});
            if obj.verbose
                fprintf('\nTesting %d candidate groups (n_group = %d)\n', n_cand, n_group);
                fprintf('%s\n', repmat('=',1,70));
            end

            % ---- scoring ----
            for c = 1:n_cand
                name = candidates{c}{1};
                G = candidates{c}{2};
                X_c = obj.adapt_data(G);
                comp = obj.svd_compressibility(G, X_c);
                spar = obj.fourier_sparsity(G, X_c);
                pr2 = NaN;
                if ~isempty(obj.y), pr2 = obj.prediction_score(G, X_c); end
                if isnan(pr2) || sw == 0
                    combined = 0.5*comp + 0.5*spar;
                else
                    % negative R2 is clamped to 0 so it cannot lower the score
                    combined = (1-sw)/2*comp + (1-sw)/2*spar + sw*max(pr2, 0);
                end
                report(c) = struct('name',name,'order',G.n,'svd_compress',comp, ...
                    'fourier_sparsity',spar,'prediction_r2',pr2,'combined_score',combined);
                if obj.verbose
                    fprintf(' %-10s (order %2d): compress=%.3f sparsity=%.3f', name, G.n, comp, spar);
                    if ~isnan(pr2), fprintf(' R2=%+.3f', pr2); end
                    fprintf(' -> %.4f\n', combined);
                end
            end

            all_s = [report.combined_score];
            [~, bi] = max(all_s);
            best_G = candidates{bi}{2};
            obj.best_group = best_G;
            obj.best_group_name = candidates{bi}{1};
            obj.candidate_results = report;
            if obj.verbose
                fprintf('\n>>> Best group: %s (order %d), score = %.4f\n', ...
                    obj.best_group_name, best_G.n, all_s(bi));
            end
        end

        %% SCORING ========================================================
        function comp = svd_compressibility(obj, G, X_c)
            % Mean, over at most the first 200 samples, of the fraction of
            % squared singular-value mass carried by the leading singular
            % value of the transformed (feature x group) slice.
            n_samp = min(200, size(X_c,1));
            ratios = zeros(n_samp, 1);
            for i = 1:n_samp
                sv = svd(obj.sample_spectrum(G, X_c, i));
                if length(sv) > 1
                    ratios(i) = sv(1)^2 / (sum(sv.^2) + 1e-20);
                else
                    ratios(i) = 1;
                end
            end
            comp = mean(ratios);
        end

        function spar = fourier_sparsity(obj, G, X_c)
            % Mean, over at most the first 200 samples, of the Gini
            % coefficient of the transformed coefficient magnitudes.
            n_samp = min(200, size(X_c,1));
            ginis = zeros(n_samp, 1);
            for i = 1:n_samp
                Xi_hat = obj.sample_spectrum(G, X_c, i);
                vals = sort(abs(Xi_hat(:)));
                nn = length(vals);
                if nn < 2 || sum(vals) < 1e-20
                    ginis(i) = 0;
                else
                    ginis(i) = 1 - 2*sum(cumsum(vals))/(nn*sum(vals)) + 1/nn;
                end
            end
            spar = mean(ginis);
        end

        function r2 = prediction_score(obj, G, X_c)
            % Held-out R2 of a ridge regression on extractStarGFeatures output.
            % Samples are split 50/20/rest into train/validation/test with a
            % fixed seed; lambda is picked on the validation part. Returns NaN
            % on any error or when the train or test part is empty.
            try
                n = size(X_c,1); nf = size(X_c,2);
                % Fixed-seed shuffle, then put the caller's global generator
                % state back so scoring does not reseed the whole session.
                saved_rng = rng;
                rng(0,'twister'); idx = randperm(n);
                rng(saved_rng);
                ntr = round(0.5*n); nva = round(0.2*n);
                tri = idx(1:ntr); vai = idx(ntr+1:ntr+nva); tei = idx(ntr+nva+1:end);
                if isempty(tri) || isempty(tei)
                    % An empty test set would evaluate to 1 - 0/1e-20 = 1,
                    % i.e. a perfect score computed from no data at all.
                    r2 = NaN;
                    return;
                end
                [ftr,np] = extractStarGFeatures(X_c(tri,:,:), G, nf);
                fva = extractStarGFeatures(X_c(vai,:,:), G, nf, np);
                fte = extractStarGFeatures(X_c(tei,:,:), G, nf, np);
                pp = size(ftr,2);
                R = eye(pp); R(1,1) = 0;          % first coefficient is left unpenalised
                A = ftr'*ftr;                     % normal-equation terms are lambda-independent
                b = ftr'*obj.y(tri);
                lams = [1e-4,1e-3,0.01,0.1,1,10]; be = Inf; bl = 0.01;
                for lam = lams
                    w = (A+lam*R)\b;
                    e = mean((obj.y(vai)-fva*w).^2);
                    if e < be, be = e; bl = lam; end
                end
                w = (A+bl*R)\b;
                yp = fte * w;
                yte = obj.y(tei);
                r2 = 1 - sum((yte-yp).^2)/(sum((yte-mean(yte)).^2)+1e-20);
            catch ME
                if obj.verbose, fprintf(' [pred ERROR: %s]\n', ME.message); end
                r2 = NaN;
            end
        end

        %% DATA ADAPTATION ================================================
        function X_c = adapt_data(obj, G)
            % Resize the third dimension of X to G.n:
            %   equal length      -> X unchanged
            %   integer multiple  -> mean over consecutive blocks
            %   otherwise         -> linear interpolation onto G.n points
            ns = size(obj.X,1); nfeat = size(obj.X,2);
            nd = size(obj.X,3); ng = G.n;
            if nd == ng
                X_c = obj.X;
            elseif mod(nd, ng) == 0
                fold = nd/ng;
                X_c = zeros(ns, nfeat, ng);
                for k = 1:ng
                    X_c(:,:,k) = mean(obj.X(:,:,(k-1)*fold+1:k*fold), 3);
                end
            else
                X_c = zeros(ns, nfeat, ng);
                oi = linspace(1, nd, ng);
                for s = 1:ns
                    for f = 1:nfeat
                        X_c(s,f,:) = interp1(1:nd, squeeze(obj.X(s,f,:)), oi, 'linear');
                    end
                end
            end
        end

        %% LEARNED ALGEBRA (vectorized Tucker products) ===================
        function [F_learned, C_learned, err] = learn_algebra(obj, varargin)
            % Alternate between a thresholded core solve and a projected
            % gradient step on F so that C x_1 F x_2 F x_3 conj(F) approximates
            % the tensor from estimate_convolution_tensor.
            %
            % Name/value options: 'max_iter' (100), 'lambda' (0.02), 'tol' (1e-6).
            % err is the relative reconstruction error of the last iteration,
            % or NaN when max_iter is 0.
            pa = inputParser;
            addParameter(pa,'max_iter',100);
            addParameter(pa,'lambda',0.02);
            addParameter(pa,'tol',1e-6);
            parse(pa, varargin{:});
            n = size(obj.X, 3);
            T_est = obj.estimate_convolution_tensor();
            F = fft(eye(n)) / sqrt(n);                  % start from the unitary DFT matrix
            C = zeros(n,n,n);
            for k = 1:n, C(k,k,k) = 1; end              % superdiagonal core; returned only if max_iter is 0
            err = NaN;                                  % defined even when the loop body never runs
            prev_err = Inf;
            for iter = 1:pa.Results.max_iter
                C = obj.solve_core_vec(F, T_est, pa.Results.lambda);
                F = obj.update_fourier_vec(F, C, T_est);
                T_rec = nmode(nmode(nmode(C, F, 1), F, 2), conj(F), 3);
                err = norm(T_est(:)-T_rec(:)) / (norm(T_est(:))+1e-20);
                if obj.verbose && mod(iter,25)==0
                    fprintf(' learn_algebra iter %3d: error = %.6f\n', iter, err);
                end
                if abs(prev_err-err) < pa.Results.tol
                    if obj.verbose, fprintf(' Converged at iter %d (err=%.2e)\n', iter, err); end
                    break;
                end
                prev_err = err;
            end
            F_learned = F; C_learned = C;
            if obj.verbose
                nnz_c = sum(abs(C(:)) > 1e-6);
                fprintf(' Learned C: %d/%d nonzero (%.1f%% sparse)\n', nnz_c, numel(C), 100*(1-nnz_c/numel(C)));
            end
        end

        function T = estimate_convolution_tensor(obj)
            % Average, over at most the first 500 samples, of the triple outer
            % product of each sample's feature-mean vector; then mean-centred
            % and scaled by the largest absolute entry.
            n = size(obj.X,3); ns = min(500, size(obj.X,1));
            nfeat = size(obj.X,2);
            T = zeros(n,n,n);
            for s = 1:ns
                % reshape (not squeeze) keeps a single-feature slice as 1 x n,
                % so the mean over features stays a length-n vector
                Xs = reshape(obj.X(s,:,:), nfeat, n);
                v = mean(Xs, 1);                         % 1 x n
                T = T + reshape(kron(v', kron(v', v')), [n,n,n]);
            end
            T = T/ns; T = T - mean(T(:)); T = T/(max(abs(T(:)))+1e-20);
        end

        function C = solve_core_vec(~, F, T, lambda)
            % Core for the model T ~ C x_1 F x_2 F x_3 conj(F), followed by
            % soft-thresholding at lambda * max|C|.
            %
            % Mode 3 of the model uses conj(F), so its inverse is
            % conj(inv(F)); with that choice the un-thresholded core
            % reproduces T exactly under the reconstruction used in
            % learn_algebra and update_fourier_vec.
            Fi = inv(F);
            C = nmode(nmode(nmode(T, Fi, 1), Fi, 2), conj(Fi), 3);
            thr = lambda * max(abs(C(:)));
            C = sign(C) .* max(abs(C) - thr, 0);
        end

        function F = update_fourier_vec(~, F, C, T)
            % One gradient step on F with fixed step 0.005, followed by
            % projection onto the unitary matrices via the SVD polar factor.
            %
            % Only the mode-1 occurrence of F is differentiated; the factors
            % on modes 2 and 3 are held fixed inside Q:
            %   grad(a,i) = sum_{b,c} res(a,b,c) * conj(Q(i,b,c)),
            %   Q = C x_2 F x_3 conj(F).
            lr = 0.005;
            n = size(F, 1);
            Q = nmode(nmode(C, F, 2), conj(F), 3);      % n x n x n
            res = nmode(Q, F, 1) - T;                   % reconstruction residual
            res_unf = reshape(res, n, []);              % mode-1 unfolding, n x n^2
            Q_unf = reshape(Q, n, []);                  % mode-1 unfolding, n x n^2
            grad = res_unf * Q_unf';                    % n x n
            F = F - lr * grad;
            [U,~,V] = svd(F); F = U*V';                 % project to unitary
        end

        function d = divisors(~, n)
            % All positive divisors of n in ascending order.
            d = find(mod(n, 1:n) == 0);
        end
    end

    methods (Access = private)
        function candidates = try_add_candidate(obj, candidates, name, varargin)
            % Append {name, StarGAlgebra(varargin{:})}; a construction failure
            % skips the candidate and is reported when verbose.
            try
                candidates{end+1} = {name, StarGAlgebra(varargin{:})};
            catch ME
                if obj.verbose
                    fprintf(' [skipping %s: %s]\n', name, ME.message);
                end
            end
        end

        function Xi_hat = sample_spectrum(~, G, X_c, i)
            % Transform sample i along the group axis: FFT for cyclic groups,
            % otherwise multiplication of every feature row by G.F.
            % reshape (not squeeze) keeps a single-feature slice as a 1 x n
            % row, so the transform still runs along the group axis.
            Xi = reshape(X_c(i,:,:), size(X_c,2), size(X_c,3));
            if G.isCyclic
                Xi_hat = fft(Xi, [], 2);
            else
                Xi_hat = (G.F * Xi.').';
            end
        end
    end
end
%% ========================================================================
%% N-MODE PRODUCT (standalone function, used by learn_algebra)
%% T x_mode M: out(..,i,..) = sum_j M(i,j) * T(..,j,..), with i and j in position `mode`
%% ========================================================================
function T_out = nmode(T, M, mode)
% NMODE  Mode-n product of a tensor with a matrix.
%   T_out = nmode(T, M, mode) multiplies M into dimension `mode` of T:
%   the result has size(M,1) along that dimension and all other
%   dimensions unchanged.
%
%   `mode` may exceed ndims(T); the missing trailing dimensions are
%   treated as singletons (MATLAB drops them from size(T), so an
%   n x n x 1 tensor reports only two dimensions).
    nd = max(ndims(T), mode);
    sz = ones(1, nd);
    sz(1:ndims(T)) = size(T);          % pad with ones up to `mode`

    % Bring the target mode to the front and unfold to a matrix
    order = [mode, setdiff(1:nd, mode)];
    T_unf = reshape(permute(T, order), sz(mode), []);

    % Apply M along the unfolded mode
    T_unf_out = M * T_unf;

    % Fold back and undo the permutation
    new_sz = sz;
    new_sz(mode) = size(M, 1);
    T_perm_out = reshape(T_unf_out, [new_sz(mode), new_sz(order(2:end))]);
    inv_order = zeros(1, nd);
    inv_order(order) = 1:nd;
    T_out = permute(T_perm_out, inv_order);
end