%% ========================================================================
%% symmetry_discovery.m
%% Discover the group structure that best explains symmetry in data
%%
%% PERFORMANCE: Tucker products use n-mode matrix multiplications
%% instead of O(n^6) nested loops.
%%
%% LH & Claude 2026
%% ========================================================================

classdef symmetry_discovery < handle
    % SYMMETRY_DISCOVERY  Score candidate finite groups against a data tensor.
    %
    %   The data tensor X is sliced throughout as X(sample, feature, position):
    %   dimension 1 is iterated as samples (and aligned with y), dimension 3
    %   is the axis whose length determines which group orders are tried.
    %
    %   Typical use:
    %       sd = symmetry_discovery(X, y);
    %       [G, report] = sd.discover('max_order', 8);

    properties
        X                   % data tensor, indexed X(sample, feature, position)
        y                   % optional targets stored as a column; [] if absent
        candidate_results   % struct array written by discover()
        best_group          % best-scoring candidate object from discover()
        best_group_name     % label of best_group, e.g. 'Z_4'
        verbose             % true (default) to print progress with fprintf
    end

    methods
        function obj = symmetry_discovery(X, y)
            % Store the data; y is optional and is reshaped to a column.
            obj.X = X;
            if nargin >= 2 && ~isempty(y)
                obj.y = y(:);
            else
                obj.y = [];
            end
            obj.verbose = true;
        end

        function [best_G, report] = discover(obj, varargin)
            % DISCOVER  Build candidate groups, score each one, keep the best.
            %
            % Name-value options:
            %   'max_order'         largest group order to try; 0 (default)
            %                       means size(X,3).
            %   'test_dihedral'     include D_d candidates (default true).
            %   'test_symmetric'    accepted for compatibility; not used here.
            %   'test_klein4'       include V_4 when 4 divides size(X,3).
            %   'test_quaternion'   include Q_8 when 8 divides size(X,3).
            %   'supervised_weight' weight of max(R^2,0) in the combined score
            %                       (default 0.4); forced to 0 when y is empty.
            %
            % Returns:
            %   best_G  candidate object with the highest combined score, or
            %           [] (with a warning) when no candidate was constructed.
            %   report  struct array with fields name, order, svd_compress,
            %           fourier_sparsity, prediction_r2, combined_score.
            p = inputParser;
            addParameter(p, 'max_order', 0);
            addParameter(p, 'test_dihedral', true);
            addParameter(p, 'test_symmetric', false);
            addParameter(p, 'test_klein4', true);
            addParameter(p, 'test_quaternion', true);
            addParameter(p, 'supervised_weight', 0.4);
            parse(p, varargin{:});
            opts = p.Results;

            n_group = size(obj.X, 3);
            max_ord = opts.max_order;
            if max_ord == 0, max_ord = n_group; end
            sw = opts.supervised_weight;
            if isempty(obj.y), sw = 0; end

            % ---- Candidate construction --------------------------------
            candidates = {};
            divs = obj.divisors(n_group);
            divs = divs(divs >= 2 & divs <= max_ord);
            for d = divs(:)'
                candidates = obj.append_candidate(candidates, sprintf('Z_%d', d), 'cyclic', d);
            end
            if opts.test_dihedral
                for d = divs(:)'
                    % D_d has order 2d; d >= 3 is required by the original rule.
                    if mod(n_group, 2*d) == 0 && 2*d <= max_ord && d >= 3
                        candidates = obj.append_candidate(candidates, sprintf('D_%d', d), 'dihedral', d);
                    end
                end
            end
            % V_4 (order 4) and Q_8 (order 8) must also respect max_order.
            if opts.test_klein4 && mod(n_group, 4) == 0 && 4 <= max_ord
                candidates = obj.append_candidate(candidates, 'V_4', 'klein4');
            end
            if opts.test_quaternion && mod(n_group, 8) == 0 && 8 <= max_ord
                candidates = obj.append_candidate(candidates, 'Q_8', 'quaternion');
            end

            n_cand = length(candidates);
            report = struct('name',{},'order',{},'svd_compress',{}, ...
                'fourier_sparsity',{},'prediction_r2',{},'combined_score',{});

            if obj.verbose
                fprintf('\nTesting %d candidate groups (n_group = %d)\n', n_cand, n_group);
                fprintf('%s\n', repmat('=',1,70));
            end

            % Nothing to rank: return cleanly instead of indexing with [].
            if n_cand == 0
                best_G = [];
                obj.best_group = [];
                obj.best_group_name = '';
                obj.candidate_results = report;
                warning('symmetry_discovery:noCandidates', ...
                    'No candidate group could be constructed for n_group = %d.', n_group);
                return;
            end

            % ---- Scoring -----------------------------------------------
            for c = 1:n_cand
                name = candidates{c}{1};
                G = candidates{c}{2};
                X_c = obj.adapt_data(G);

                comp = obj.svd_compressibility(G, X_c);
                spar = obj.fourier_sparsity(G, X_c);
                pr2 = NaN;
                if ~isempty(obj.y), pr2 = obj.prediction_score(G, X_c); end

                if isnan(pr2) || sw == 0
                    % Unsupervised score: equal weight on both structure terms.
                    combined = 0.5*comp + 0.5*spar;
                else
                    % Negative R^2 is clipped to 0 so it cannot lower the score.
                    combined = (1-sw)/2*comp + (1-sw)/2*spar + sw*max(pr2, 0);
                end

                report(c) = struct('name', name, 'order', G.n, ...
                    'svd_compress', comp, 'fourier_sparsity', spar, ...
                    'prediction_r2', pr2, 'combined_score', combined);

                if obj.verbose
                    fprintf('  %-10s (order %2d): compress=%.3f  sparsity=%.3f', name, G.n, comp, spar);
                    if ~isnan(pr2), fprintf('  R2=%+.3f', pr2); end
                    fprintf('  -> %.4f\n', combined);
                end
            end

            all_s = [report.combined_score];
            [~, bi] = max(all_s);
            best_G = candidates{bi}{2};
            obj.best_group = best_G;
            obj.best_group_name = candidates{bi}{1};
            obj.candidate_results = report;
            if obj.verbose
                fprintf('\n>>> Best group: %s (order %d), score = %.4f\n', ...
                    obj.best_group_name, best_G.n, all_s(bi));
            end
        end

        %% SCORING ========================================================
        function comp = svd_compressibility(~, G, X_c)
            % Mean, over at most the first 200 samples, of the fraction of
            % squared singular-value mass carried by the largest singular
            % value of the transformed (feature x position) slice.
            n_samp = min(200, size(X_c,1));
            ratios = zeros(n_samp, 1);
            for i = 1:n_samp
                Xi = symmetry_discovery.sample_matrix(X_c, i);
                sv = svd(symmetry_discovery.group_transform(G, Xi));
                if length(sv) > 1
                    ratios(i) = sv(1)^2 / (sum(sv.^2) + 1e-20);
                else
                    % A single singular value is trivially rank one.
                    ratios(i) = 1;
                end
            end
            comp = mean(ratios);
        end

        function spar = fourier_sparsity(~, G, X_c)
            % Mean, over at most the first 200 samples, of the Gini
            % coefficient of the transformed coefficient magnitudes.
            n_samp = min(200, size(X_c,1));
            ginis = zeros(n_samp, 1);
            for i = 1:n_samp
                Xi = symmetry_discovery.sample_matrix(X_c, i);
                Xi_hat = symmetry_discovery.group_transform(G, Xi);
                vals = sort(abs(Xi_hat(:)));
                nn = length(vals);
                if nn < 2 || sum(vals) < 1e-20
                    % Degenerate slice: too few values, or all (near) zero.
                    ginis(i) = 0;
                else
                    ginis(i) = 1 - 2*sum(cumsum(vals))/(nn*sum(vals)) + 1/nn;
                end
            end
            spar = mean(ginis);
        end

        function r2 = prediction_score(obj, G, X_c)
            % Test-set R^2 of a ridge regression on extractStarGFeatures
            % output. Returns NaN (and prints the message when verbose) if
            % anything inside throws.
            %
            % NOTE: rng(0,'twister') reseeds the GLOBAL random stream so the
            % 50/20/30 train/validation/test split is reproducible.
            try
                n = size(X_c,1); nf = size(X_c,2);
                rng(0,'twister'); idx = randperm(n);
                ntr = round(0.5*n); nva = round(0.2*n);
                tri = idx(1:ntr);
                vai = idx(ntr+1:ntr+nva);
                tei = idx(ntr+nva+1:end);

                % np is the second output of the training call and is passed
                % back in for the validation and test calls.
                [ftr,np] = extractStarGFeatures(X_c(tri,:,:), G, nf);
                fva = extractStarGFeatures(X_c(vai,:,:), G, nf, np);
                fte = extractStarGFeatures(X_c(tei,:,:), G, nf, np);

                % Ridge penalty on every coefficient except the first one.
                pp = size(ftr,2); R = eye(pp); R(1,1) = 0;
                lams = [1e-4,1e-3,0.01,0.1,1,10]; be = Inf; bl = 0.01;
                for lam = lams
                    w = (ftr'*ftr+lam*R)\(ftr'*obj.y(tri));
                    e = mean((obj.y(vai)-fva*w).^2);
                    if e < be, be = e; bl = lam; end
                end
                % Refit with the lambda that minimised validation MSE.
                w = (ftr'*ftr+bl*R)\(ftr'*obj.y(tri));
                yp = fte * w;
                yte = obj.y(tei);
                r2 = 1 - sum((yte-yp).^2)/(sum((yte-mean(yte)).^2)+1e-20);
            catch ME
                if obj.verbose, fprintf('    [pred ERROR: %s]\n', ME.message); end
                r2 = NaN;
            end
        end

        %% DATA ADAPTATION ================================================
        function X_c = adapt_data(obj, G)
            % Resize dimension 3 of X to G.n positions.
            %   equal length     -> X returned unchanged
            %   G.n divides it   -> mean over consecutive blocks
            %   otherwise        -> linear interpolation onto G.n points
            nd = size(obj.X,3); ng = G.n;
            if nd == ng
                X_c = obj.X;
            elseif mod(nd, ng) == 0
                fold = nd/ng;
                X_c = zeros(size(obj.X,1), size(obj.X,2), ng);
                for k = 1:ng
                    X_c(:,:,k) = mean(obj.X(:,:,(k-1)*fold+1:k*fold), 3);
                end
            else
                X_c = zeros(size(obj.X,1), size(obj.X,2), ng);
                oi = linspace(1, nd, ng);
                for s = 1:size(obj.X,1)
                    for f = 1:size(obj.X,2)
                        X_c(s,f,:) = interp1(1:nd, squeeze(obj.X(s,f,:)), oi, 'linear');
                    end
                end
            end
        end

        %% LEARNED ALGEBRA (vectorized Tucker products) ===================
        function [F_learned, C_learned, err] = learn_algebra(obj, varargin)
            % Alternately fit a sparse core C and a unitary factor F so that
            %   C x_1 F x_2 F x_3 conj(F)
            % approximates the tensor from estimate_convolution_tensor().
            %
            % Name-value options:
            %   'max_iter' (100), 'lambda' (0.02, relative soft threshold),
            %   'tol' (1e-6, stop when the error changes by less than this).
            %
            % Returns F_learned, C_learned and the relative reconstruction
            % error of the last iteration (NaN if no iteration ran).
            pa = inputParser;
            addParameter(pa,'max_iter',100); addParameter(pa,'lambda',0.02);
            addParameter(pa,'tol',1e-6); parse(pa, varargin{:});

            n = size(obj.X, 3);
            T_est = obj.estimate_convolution_tensor();

            % Start from the unitary DFT matrix and a superdiagonal core.
            F = fft(eye(n)) / sqrt(n);
            C = zeros(n,n,n); for k=1:n, C(k,k,k)=1; end

            err = NaN;          % defined even when max_iter is 0
            prev_err = Inf;
            for iter = 1:pa.Results.max_iter
                C = obj.solve_core_vec(F, T_est, pa.Results.lambda);
                F = obj.update_fourier_vec(F, C, T_est);
                T_rec = nmode(nmode(nmode(C, F, 1), F, 2), conj(F), 3);
                err = norm(T_est(:)-T_rec(:)) / (norm(T_est(:))+1e-20);
                if obj.verbose && mod(iter,25)==0
                    fprintf('  learn_algebra iter %3d: error = %.6f\n', iter, err);
                end
                if abs(prev_err-err) < pa.Results.tol
                    if obj.verbose, fprintf('  Converged at iter %d (err=%.2e)\n', iter, err); end
                    break;
                end
                prev_err = err;
            end
            F_learned = F; C_learned = C;
            if obj.verbose
                nnz_c = sum(abs(C(:)) > 1e-6);
                fprintf('  Learned C: %d/%d nonzero (%.1f%% sparse)\n', nnz_c, numel(C), 100*(1-nnz_c/numel(C)));
            end
        end

        function T = estimate_convolution_tensor(obj)
            % Average third-order outer product v (x) v (x) v of the
            % feature-mean profile v of each of the first (at most) 500
            % samples, then mean-centred and scaled to max |T| = 1.
            n = size(obj.X,3); ns = min(500, size(obj.X,1));
            nf = size(obj.X,2);
            T = zeros(n,n,n);
            for s = 1:ns
                % reshape (not squeeze) keeps a 1 x n row when nf == 1.
                Xs = reshape(obj.X(s,:,:), nf, n);
                v = mean(Xs, 1);  % 1 x n
                T = T + reshape(kron(v', kron(v', v')), [n,n,n]);
            end
            T = T/ns; T = T - mean(T(:)); T = T/(max(abs(T(:)))+1e-20);
        end

        function C = solve_core_vec(~, F, T, lambda)
            % Invert the reconstruction T = C x_1 F x_2 F x_3 conj(F):
            %   C = T x_1 F^{-1} x_2 F^{-1} x_3 conj(F)^{-1}
            % then soft-threshold at lambda * max|C|.
            %
            % conj(F)^{-1} equals conj(F^{-1}); using F' on mode 3 would only
            % be correct for real F.
            Fi = inv(F);
            C = nmode(nmode(nmode(T, Fi, 1), Fi, 2), conj(Fi), 3);
            thr = lambda * max(abs(C(:)));
            C = sign(C) .* max(abs(C) - thr, 0);
        end

        function F = update_fourier_vec(~, F, C, T)
            % One fixed-step gradient update of F followed by projection
            % onto the unitary matrices via the SVD polar factor.
            %
            % Only the mode-1 occurrence of F is differentiated; the copies
            % of F on modes 2 and 3 are held fixed inside Q:
            %   Q(i,b,c)  = sum_{j,k} F(b,j) * conj(F(c,k)) * C(i,j,k)
            %   grad(a,i) = sum_{b,c} res(a,b,c) * conj(Q(i,b,c))
            lr = 0.005;
            T_rec = nmode(nmode(nmode(C, F, 1), F, 2), conj(F), 3);
            res = T_rec - T;

            n = size(F, 1);
            Q = nmode(nmode(C, F, 2), conj(F), 3);  % n x n x n
            res_unf = reshape(res, n, []);    % n x n^2 (mode-1 unfolding)
            Q_unf = reshape(Q, n, []);        % n x n^2 (mode-1 unfolding)
            grad = res_unf * Q_unf';          % n x n

            F = F - lr * grad;
            [U,~,V] = svd(F); F = U*V';  % project to unitary
        end

        function d = divisors(~, n)
            % Positive divisors of n in ascending order, as a row vector.
            d = find(mod(n, 1:n) == 0);
        end
    end

    methods (Access = private)
        function candidates = append_candidate(obj, candidates, label, varargin)
            % Try to construct StarGAlgebra(varargin{:}) and append
            % {label, G} to the list. A failing constructor skips the
            % candidate (best effort) but is reported when verbose.
            try
                candidates{end+1} = {label, StarGAlgebra(varargin{:})};
            catch ME
                if obj.verbose
                    fprintf('  [skipping %s: %s]\n', label, ME.message);
                end
            end
        end
    end

    methods (Static, Access = private)
        function Xi = sample_matrix(X_c, i)
            % Slice sample i as a (size(X_c,2) x size(X_c,3)) matrix.
            % reshape is used instead of squeeze so that a singleton
            % dimension cannot flip the orientation of the slice.
            Xi = reshape(X_c(i,:,:), size(X_c,2), size(X_c,3));
        end

        function Xi_hat = group_transform(G, Xi)
            % Transform every row of Xi: FFT along dimension 2 for cyclic
            % groups, otherwise multiply each row by G.F
            % (row_hat = (G.F * row.').'), done as one matrix product.
            if G.isCyclic
                Xi_hat = fft(Xi, [], 2);
            else
                Xi_hat = Xi * G.F.';
            end
        end
    end
end

%% ========================================================================
%% N-MODE PRODUCT (standalone function, used by learn_algebra)
%% T x_mode M  = contract the mode-th index of T with the column index of M:
%%   out(..., i, ...) = sum_a M(i, a) * T(..., a, ...)
%% ========================================================================
function T_out = nmode(T, M, mode)
    % NMODE  Mode-n product of a tensor with a matrix.
    %
    %   T_out = nmode(T, M, mode) multiplies M into dimension `mode` of T:
    %       T_out(..., i, ...) = sum_a M(i, a) * T(..., a, ...)
    %   so size(T_out, mode) == size(M, 1) and every other dimension of T
    %   is unchanged.
    %
    %   Errors with id 'nmode:dimensionMismatch' when size(M, 2) differs
    %   from size(T, mode).

    % size(T) omits trailing singleton dimensions, so pad with ones to make
    % sz(mode) valid when mode > ndims(T) (e.g. an n-by-n-by-1 tensor).
    nd = max(ndims(T), mode);
    sz = [size(T), ones(1, nd - ndims(T))];

    if size(M, 2) ~= sz(mode)
        error('nmode:dimensionMismatch', ...
            'size(M,2) = %d does not match size(T,%d) = %d.', ...
            size(M, 2), mode, sz(mode));
    end

    % Bring the target mode to the front and unfold to a matrix.
    order = [mode, setdiff(1:nd, mode)];
    T_unf = reshape(permute(T, order), sz(mode), []);

    % Multiply along the unfolded mode.
    T_unf_out = M * T_unf;

    % Fold back using the new length of the target mode.
    new_sz = sz;
    new_sz(mode) = size(M, 1);
    T_perm_out = reshape(T_unf_out, [new_sz(mode), new_sz(order(2:end))]);

    % Undo the permutation so the dimensions return to their original order.
    inv_order = zeros(1, nd);
    inv_order(order) = 1:nd;
    T_out = permute(T_perm_out, inv_order);
end
