%% ============================================================================
%% GPU-ACCELERATED StarGAlgebra CLASS
%% ============================================================================
% LH & SU 2026
%% ============================================================================

classdef StarGAlgebra
    properties
        % Multiplication table: G(a, b) is the index of the product a*b.
        G
        % Group order (number of elements).
        n
        % Fourier matrix used by the transform-based methods, and the
        % matrix applied for the inverse transform (conj(F)/n, F/4, inv(F)
        % or a kron of the factors' Finv, depending on the initializer).
        F
        Finv
        % Irrep dimensions, when the initializer records them (empty
        % otherwise).
        irrepDims
        % n x n x n 0/1 tensor: convTensor(a, b, c) == 1 iff G(a, b) == c.
        convTensor
        % invTable(a) is the index b with G(a, b) equal to the identity.
        invTable
        % True when a parallel pool already existed at construction time.
        useParallel
        % Toggled by enableGPU / disableGPU; when true the convolution and
        % star-product routines move their data onto the GPU.
        useGPU
        % True when G is symmetric (a*b == b*a for all a, b).
        isAbelian
        % Set by the initializers; true selects the FFT-based fast paths.
        isCyclic
        % Index of the identity element.
        identityIdx

        % GPU arrays (gpuArray copies of G, convTensor, invTable, F and
        % Finv; filled by initGPUArrays, cleared by disableGPU)
        G_gpu
        convTensor_gpu
        invTable_gpu
        F_gpu
        Finv_gpu
    end

    methods
        function obj = StarGAlgebra(groupType, varargin)
            % StarGAlgebra  Build the star-G algebra of a finite group.
            %   obj = StarGAlgebra(groupType)
            %   obj = StarGAlgebra(groupType, groupParam)
            %
            %   groupType  - 'cyclic', 'dihedral', 'symmetric', 'klein4',
            %                'octahedral', 'quaternion', 'product' or
            %                'table' (case-insensitive).
            %   groupParam - forwarded to the initializer: the order
            %                parameter for 'cyclic' / 'dihedral' /
            %                'symmetric', a cell array of StarGAlgebra
            %                objects for 'product', the multiplication table
            %                for 'table'. Ignored by the other group types.
            %
            %   After the group-specific initializer runs, the convolution
            %   tensor, inverse table, identity index and abelian flag are
            %   derived from the multiplication table G.
            obj.useGPU = false;

            % gcp belongs to the Parallel Computing Toolbox; only query it
            % when it exists so construction works without that toolbox.
            obj.useParallel = false;
            if ~isempty(which('gcp'))
                obj.useParallel = ~isempty(gcp('nocreate'));
            end

            groupParam = [];
            if ~isempty(varargin)
                groupParam = varargin{1};
            end

            switch lower(groupType)
                case 'cyclic'
                    obj = obj.initCyclic(groupParam);
                case 'dihedral'
                    obj = obj.initDihedral(groupParam);
                case 'symmetric'
                    obj = obj.initSymmetric(groupParam);
                case 'klein4'
                    obj = obj.initKlein4();
                case 'octahedral'
                    obj = obj.initOctahedral();
                case 'quaternion'
                    obj = obj.initQuaternion();
                case 'product'
                    obj = obj.initProduct(groupParam);
                case 'table'
                    obj = obj.initFromTable(groupParam);
                otherwise
                    error('Unknown group type: %s', groupType);
            end

            obj = obj.buildConvolutionTensor();
            obj = obj.buildInverseTable();
            obj = obj.findIdentity();
            obj.isAbelian = obj.checkAbelian();
        end

        %% GPU Methods
        function obj = enableGPU(obj)
            % enableGPU  Switch to GPU execution and upload the cached tables.
            %   When no GPU is usable (including when gpuDeviceCount, from
            %   the Parallel Computing Toolbox, is not installed) a warning
            %   is issued and the object is returned unchanged.
            hasGPU = ~isempty(which('gpuDeviceCount')) && gpuDeviceCount > 0;
            if hasGPU
                obj.useGPU = true;
                obj = obj.initGPUArrays();
                % Store the device first: indexing straight into a function
                % call result is not supported by every MATLAB release.
                dev = gpuDevice();
                fprintf('GPU enabled: %s\n', dev.Name);
            else
                warning('No GPU available');
            end
        end

        function obj = initGPUArrays(obj)
            % initGPUArrays  Mirror the group tables onto the GPU.
            %   The two index tables are uploaded as int32; the convolution
            %   tensor and the Fourier matrices keep their numeric class.
            obj.F_gpu          = gpuArray(obj.F);
            obj.Finv_gpu       = gpuArray(obj.Finv);
            obj.convTensor_gpu = gpuArray(obj.convTensor);
            obj.G_gpu          = gpuArray(int32(obj.G));
            obj.invTable_gpu   = gpuArray(int32(obj.invTable));
        end

        function obj = disableGPU(obj)
            % disableGPU  Return to CPU execution and drop every GPU copy.
            gpuFields = {'G_gpu', 'convTensor_gpu', 'invTable_gpu', 'F_gpu', 'Finv_gpu'};
            for f = 1:numel(gpuFields)
                obj.(gpuFields{f}) = [];
            end
            obj.useGPU = false;
        end

        %% Core Setup
        function obj = findIdentity(obj)
            % findIdentity  Locate the two-sided identity of the table.
            %   Sets identityIdx to the first index e whose row and column of
            %   G both read 1..n (so e*a == a*e == a for every a). Errors
            %   when no such index exists.
            expected = 1:obj.n;
            for e = expected
                rowOk = all(obj.G(e, :) == expected);
                colOk = all(obj.G(:, e).' == expected);
                if rowOk && colOk
                    obj.identityIdx = e;
                    return;
                end
            end
            error('No identity element found');
        end

        function obj = buildInverseTable(obj)
            % buildInverseTable  Fill invTable(a) with the index of a^{-1}.
            %   The identity is the first element whose row AND column of G
            %   read 1..n; a row-only test can pick a one-sided identity in a
            %   malformed table. For each a, invTable(a) is the first b with
            %   G(a, b) equal to the identity.
            %
            %   Errors instead of silently producing bad data:
            %     'No identity element found'  - the table has no identity
            %                                    (previously index 1 was used);
            %     StarGAlgebra:NoInverse       - some element has no inverse
            %                                    (previously invTable stayed 0,
            %                                    an invalid index).
            ng = obj.n;
            ids = 1:ng;
            rowIsId = all(obj.G == ids, 2);          % G(e, :) == 1:n
            colIsId = all(obj.G == ids.', 1).';      % G(:, e) == (1:n)'
            e = find(rowIsId & colIsId, 1);
            if isempty(e)
                error('No identity element found');
            end

            obj.invTable = zeros(ng, 1);
            for a = 1:ng
                b = find(obj.G(a, :) == e, 1);
                if isempty(b)
                    error('StarGAlgebra:NoInverse', ...
                        'Element %d has no inverse in the multiplication table', a);
                end
                obj.invTable(a) = b;
            end
        end

        function obj = buildConvolutionTensor(obj)
            % buildConvolutionTensor  Encode the group law as a 0/1 tensor.
            %   convTensor(x, y, z) is 1 exactly when G(x, y) == z, so
            %   contracting it with a(x) * b(y) yields the group convolution.
            ng = obj.n;
            T = zeros(ng, ng, ng);
            for col = 1:ng
                for row = 1:ng
                    T(row, col, obj.G(row, col)) = 1;
                end
            end
            obj.convTensor = T;
        end

        function isAb = checkAbelian(obj)
            % checkAbelian  True when a*b == b*a for every pair of elements,
            %   i.e. when the multiplication table is symmetric.
            isAb = isequal(obj.G.', obj.G);
        end

        %% Convolution Methods
        function c = convolve_direct(obj, a, b)
            % convolve_direct  Group convolution via the structure tensor.
            %   c(z) = sum over (x, y) with G(x, y) == z of a(x) * b(y).
            %   a, b : length-n vectors (any orientation); c : n-by-1 column.
            %   Runs on the GPU (and gathers the result) when useGPU is set.
            %
            %   The outer product uses the plain transpose (.'): the former
            %   a * b' conjugated b, so complex inputs disagreed with the
            %   FFT path of convolve.
            a = a(:);
            b = b(:);
            ng = obj.n;

            if obj.useGPU
                a = gpuArray(a);
                b = gpuArray(b);
                T = obj.convTensor_gpu;
            else
                T = obj.convTensor;
            end

            ab = a * b.';   % ab(x, y) = a(x) * b(y), no conjugation
            % Contract the first two modes of T with ab as one mat-vec
            % instead of materialising the n^3 product T .* ab.
            c = reshape(T, ng * ng, ng).' * ab(:);

            if obj.useGPU
                c = gather(c);
            end
        end

        function c = convolve_inverse(obj, a, b)
            % convolve_inverse  Group convolution written with inverses:
            %   c(z) = sum over g of a(g) * b(g^{-1} * z),  z = 1..n.
            %   a, b : length-n vectors (any orientation); c : n-by-1 column.
            a = a(:);
            b = b(:);
            c = zeros(obj.n, 1);
            for g = 1:obj.n
                % Row of G for g^{-1}: entry z is the index of g^{-1} * z.
                shifted = obj.G(obj.invTable(g), :);
                c = c + a(g) * b(shifted(:));
            end
        end

        function c = convolve(obj, a, b)
            % convolve  Group convolution of two length-n vectors.
            %   Cyclic groups use the FFT (on the GPU when useGPU is set) and
            %   return a real column when both inputs are real; every other
            %   group is delegated to convolve_direct.
            a = a(:);
            b = b(:);

            if ~obj.isCyclic
                c = obj.convolve_direct(a, b);
                return;
            end

            realInputs = isreal(a) && isreal(b);
            if obj.useGPU
                spectrum = fft(gpuArray(a)) .* fft(gpuArray(b));
                c = gather(ifft(spectrum));
            else
                c = ifft(fft(a) .* fft(b));
            end
            if realInputs
                c = real(c);
            end
        end

        %% Star-G Methods
        function C = starG_direct(obj, A, B)
            % starG_direct  Star-G product by direct accumulation over G.
            %   A : l-by-m-by-n, B : m-by-p-by-n  ->  C : l-by-p-by-n with
            %     C(:, :, z) = sum over (x, y) with G(x, y) == z of
            %                  A(:, :, x) * B(:, :, y).
            %   Runs on the GPU (and gathers the result) when useGPU is set.
            %
            %   Replaces the per-entry tensor contraction, which conjugated
            %   B through b_kj' and built an n^3 temporary for every
            %   (i, j, k); this form needs only n^2 slice products.
            [l, ~, ng] = size(A);
            p = size(B, 2);

            if obj.useGPU
                A = gpuArray(A);
                B = gpuArray(B);
                C = gpuArray.zeros(l, p, ng);
            else
                C = zeros(l, p, ng);
            end

            for x = 1:ng
                Ax = A(:, :, x);
                for y = 1:ng
                    z = obj.G(x, y);
                    C(:, :, z) = C(:, :, z) + Ax * B(:, :, y);
                end
            end

            if obj.useGPU
                C = gather(C);
            end
        end

        function C = starG_fourier(obj, A, B)
            % starG_fourier  Star-G product for cyclic groups via the FFT.
            %   A : l-by-m-by-n, B : m-by-p-by-n -> C : l-by-p-by-n.
            %   Transforms both tensors along mode 3, multiplies matching
            %   frontal slices (pagemtimes when available), inverts the
            %   transform, and returns a real result for real inputs.
            realInputs = isreal(A) && isreal(B);
            [rowsA, ~, nSlices] = size(A);
            colsB = size(B, 2);

            if obj.useGPU
                A = gpuArray(A);
                B = gpuArray(B);
            end
            Ahat = fft(A, [], 3);
            Bhat = fft(B, [], 3);

            hasPagemtimes = exist('pagemtimes', 'builtin') || exist('pagemtimes', 'file');
            if hasPagemtimes
                Chat = pagemtimes(Ahat, Bhat);
            else
                if obj.useGPU
                    Chat = gpuArray.zeros(rowsA, colsB, nSlices);
                else
                    Chat = zeros(rowsA, colsB, nSlices);
                end
                for t = 1:nSlices
                    Chat(:, :, t) = Ahat(:, :, t) * Bhat(:, :, t);
                end
            end

            C = ifft(Chat, [], 3);
            if obj.useGPU
                C = gather(C);
            end
            if realInputs
                C = real(C);
            end
        end

        function C = starG(obj, A, B)
            % starG  Star-G product C = A *_G B.
            %   A : l-by-m-by-n, B : m-by-p-by-n with n the group order.
            %   A 2-D operand is replicated along mode 3 into n identical
            %   slices (when n > 1). Cyclic groups use starG_fourier, all
            %   others starG_direct.
            if ismatrix(A) && obj.n > 1
                A = repmat(A, [1, 1, obj.n]);
            end
            if ismatrix(B) && obj.n > 1
                B = repmat(B, [1, 1, obj.n]);
            end

            [~, innerA, depthA] = size(A);
            [innerB, ~, depthB] = size(B);
            assert(innerA == innerB, 'Inner dimensions must match');
            assert(depthA == obj.n && depthB == obj.n, 'Mode-3 must equal group order');

            if obj.isCyclic
                C = obj.starG_fourier(A, B);
            else
                C = obj.starG_direct(A, B);
            end
        end

        %% Batch Operations
        function C = starG_batch(obj, A, B)
            % starG_batch  Star-G product applied independently to a batch.
            %   A : l-by-m-by-n-by-batch, B : m-by-p-by-n-by-batch
            %   -> C : l-by-p-by-n-by-batch with
            %      C(:, :, :, t) = starG(A(:, :, :, t), B(:, :, :, t)).
            %
            %   GPU + cyclic groups: FFT along mode 3, page-wise products,
            %   inverse FFT. The result is made real only when both inputs
            %   are real, matching starG_fourier (previously complex inputs
            %   silently lost their imaginary part here). Other cases loop
            %   over the batch with starG.
            [l, ~, ng, batch] = size(A);
            p = size(B, 2);

            if obj.useGPU && obj.isCyclic
                realInputs = isreal(A) && isreal(B);
                Ahat = fft(gpuArray(A), [], 3);
                Bhat = fft(gpuArray(B), [], 3);

                % Same availability test as starG_fourier.
                if exist('pagemtimes', 'builtin') || exist('pagemtimes', 'file')
                    % Pages span both trailing dimensions (n and batch).
                    Chat = pagemtimes(Ahat, Bhat);
                else
                    Chat = gpuArray.zeros(l, p, ng, batch);
                    for t = 1:batch
                        for k = 1:ng
                            Chat(:, :, k, t) = Ahat(:, :, k, t) * Bhat(:, :, k, t);
                        end
                    end
                end

                C = gather(ifft(Chat, [], 3));
                if realInputs
                    C = real(C);
                end
            else
                C = zeros(l, p, ng, batch);
                for t = 1:batch
                    C(:, :, :, t) = obj.starG(A(:, :, :, t), B(:, :, :, t));
                end
            end
        end

        function C = starG_old(obj, A, B)
            % ★_G product - Convolutional tensor product
            % A: l x m x n, B: m x p x n -> C: l x p x n
            %
            % Legacy routine kept alongside starG. The steps below are:
            %   1. replicate 2-D inputs along mode 3 to n slices;
            %   2. transform every mode-3 tube with obj.F' (the conjugate
            %      transpose of F);
            %   3. combine transformed slices through convTensor, i.e. a
            %      group convolution carried out in the transformed domain;
            %   4. map back with obj.Finv' and drop the imaginary part for
            %      real inputs.
            % NOTE(review): starG either multiplies slice-wise after an FFT
            % (cyclic) or convolves directly; this routine both transforms
            % AND convolves, so its agreement with starG is unverified —
            % confirm before relying on it.

            szA = size(A);
            szB = size(B);

            % Handle 2D inputs
            if ndims(A) == 2
                A = reshape(A, [szA(1), szA(2), 1]);
                if size(A,3) == 1 && obj.n > 1
                    A = repmat(A, [1, 1, obj.n]);
                end
            end
            if ndims(B) == 2
                B = reshape(B, [szB(1), szB(2), 1]);
                if size(B,3) == 1 && obj.n > 1
                    B = repmat(B, [1, 1, obj.n]);
                end
            end

            [l, m, ~] = size(A);
            [m2, p, ~] = size(B);
            n = obj.n;

            assert(m == m2, 'Inner dimensions must match');

            % Transform along mode-3 using group Fourier matrix
            % Reshape for matrix multiplication with F
            A_reshape = reshape(permute(A, [3, 1, 2]), n, []);  % n x (l*m)
            Ahat_reshape = obj.F' * A_reshape;  % n x (l*m)
            Ahat = permute(reshape(Ahat_reshape, n, l, m), [2, 3, 1]);  % l x m x n

            B_reshape = reshape(permute(B, [3, 1, 2]), n, []);  % n x (m*p)
            Bhat_reshape = obj.F' * B_reshape;  % n x (m*p)
            Bhat = permute(reshape(Bhat_reshape, n, m, p), [2, 3, 1]);  % m x p x n

            % Multiply in transform domain with Peter-Weyl structure
            % For each output position k, compute using convTensor
            Chat = zeros(l, p, n);

            % Vectorized computation using convTensor
            % (loops over every (x, y) pair with G(x, y) == k; the nonzero
            % test skips the n^2 - n pairs that do not map to k)
            for k = 1:n
                Ck_slice = zeros(l, p);
                for x = 1:n
                    for y = 1:n
                        if obj.convTensor(x, y, k) ~= 0
                            Ck_slice = Ck_slice + obj.convTensor(x,y,k) * (Ahat(:,:,x) * Bhat(:,:,y));
                        end
                    end
                end
                Chat(:,:,k) = Ck_slice;
            end

            % Transform back
            Chat_reshape = reshape(permute(Chat, [3, 1, 2]), n, []);  % n x (l*p)
            C_reshape = obj.Finv' * Chat_reshape;  % n x (l*p)
            C = permute(reshape(C_reshape, n, l, p), [2, 3, 1]);  % l x p x n

            % Real inputs: discard the (presumably round-off) imaginary part.
            if isreal(A) && isreal(B)
                C = real(C);
            end
        end

        %% Conjugate Transpose
        function Ah = conjugateTranspose(obj, A)
            % conjugateTranspose  Star-G adjoint of an m-by-p-by-n tensor.
            %   Ah(:, :, g) is the conjugate transpose of A(:, :, g^{-1}),
            %   giving a p-by-m-by-n result (filled slice by slice into a
            %   double array).
            [m, p, n] = size(A);
            Ah = zeros(p, m, n);
            for g = 1:n
                Ah(:, :, g) = A(:, :, obj.invTable(g))';
            end
        end

        function Ah = conjugateTranspose_fast(obj, A)
            % conjugateTranspose_fast  Vectorised star-G adjoint:
            %   Ah(:, :, g) = A(:, :, g^{-1})' for every slice at once.
            reordered = A(:, :, obj.invTable);
            Ah = conj(permute(reordered, [2, 1, 3]));
        end

        %% SVD
        function [U, S, V] = starG_SVD(obj, A)
            % starG_SVD  Slice-wise SVD in the transform domain.
            %   [U, S, V] = starG_SVD(obj, A) for an l-by-m-by-n tensor A:
            %     1. transform every mode-3 tube (fft for cyclic groups,
            %        left-multiplication of the tubes by obj.F otherwise);
            %     2. take the economy SVD of every frontal slice;
            %     3. map U, S, V back (ifft, or left-multiplication by
            %        obj.Finv).
            %   Outputs: U l-by-r-by-n, S r-by-r-by-n, V m-by-r-by-n with
            %   r = min(l, m). Real A yields real outputs.
            %
            %   NOTE(review): the non-cyclic path computes obj.F * tube. The
            %   dihedral/quaternion/octahedral initializers fill F with one
            %   row per group element, so this agrees with a per-element
            %   Fourier sum only when F is symmetric — confirm the intended
            %   convention. Slice-wise SVD also ignores the block structure
            %   of irreps with dimension > 1.
            [l, m, n] = size(A);
            r = min(l, m);

            % Mode-3 helpers: tubes as rows of an n-by-(d1*d2) matrix, and back.
            toTubes = @(X) reshape(permute(X, [3, 1, 2]), n, []);
            fromTubes = @(Y, d1, d2) permute(reshape(Y, n, d1, d2), [2, 3, 1]);

            if obj.isCyclic
                Ahat = fft(A, [], 3);   % fast path for cyclic groups
            else
                Ahat = fromTubes(obj.F * toTubes(A), l, m);
            end

            Uhat = zeros(l, r, n);
            Shat = zeros(r, r, n);
            Vhat = zeros(m, r, n);
            for t = 1:n
                slice = Ahat(:, :, t);
                if obj.useGPU
                    slice = gather(slice);
                end
                [Ut, St, Vt] = svd(slice, 'econ');
                w = size(Ut, 2);
                Uhat(:, 1:w, t) = Ut;
                Shat(1:w, 1:w, t) = St;
                Vhat(:, 1:w, t) = Vt;
            end

            if obj.isCyclic
                U = ifft(Uhat, [], 3);
                S = ifft(Shat, [], 3);
                V = ifft(Vhat, [], 3);
            else
                U = fromTubes(obj.Finv * toTubes(Uhat), l, r);
                S = fromTubes(obj.Finv * toTubes(Shat), r, r);
                V = fromTubes(obj.Finv * toTubes(Vhat), m, r);
            end

            if isreal(A)
                U = real(U);
                S = real(S);
                V = real(V);
            end
        end

        function Ak = truncate(obj, A, k)
            % truncate  Rank-k truncation of every transform-domain slice.
            %   Ak = truncate(obj, A, k) for an l-by-m-by-n tensor A:
            %   transform along mode 3 (fft for cyclic groups, obj.F
            %   otherwise), keep the leading k singular triplets of every
            %   frontal slice (k is capped at min(l, m)), then transform back
            %   (ifft, or obj.Finv). Real A yields a real result.
            [l, m, n] = size(A);
            k = min(k, min(l, m));

            % Mode-3 helpers: tubes as rows of an n-by-(l*m) matrix, and back.
            toTubes = @(X) reshape(permute(X, [3, 1, 2]), n, []);
            fromTubes = @(Y) permute(reshape(Y, n, l, m), [2, 3, 1]);

            if obj.isCyclic
                Ahat = fft(A, [], 3);
            else
                Ahat = fromTubes(obj.F * toTubes(A));
            end

            Akhat = zeros(l, m, n);
            for t = 1:n
                slice = Ahat(:, :, t);
                if obj.useGPU
                    slice = gather(slice);
                end
                [Ut, St, Vt] = svd(slice, 'econ');
                keep = 1:min(k, size(St, 1));
                Akhat(:, :, t) = Ut(:, keep) * St(keep, keep) * Vt(:, keep)';
            end

            if obj.isCyclic
                Ak = ifft(Akhat, [], 3);
            else
                Ak = fromTubes(obj.Finv * toTubes(Akhat));
            end

            if isreal(A)
                Ak = real(Ak);
            end
        end

        %% Group Initialization
        function obj = initCyclic(obj, n)
            % initCyclic  Cyclic group Z_n.
            %   Element t = 0..n-1 is stored at index t+1 and
            %   G(i, j) = mod((i-1) + (j-1), n) + 1. F = fft(eye(n)) is the
            %   DFT matrix and Finv = conj(F)/n its inverse. Every irrep of
            %   Z_n is one-dimensional, recorded in irrepDims.
            %
            %   Errors with StarGAlgebra:InvalidOrder unless n is a positive
            %   integer scalar (e.g. StarGAlgebra('cyclic') with no order
            %   used to fail obscurely or build a meaningless object).
            if isempty(n) || ~isscalar(n) || ~isnumeric(n) || ~isreal(n) || n < 1 || n ~= fix(n)
                error('StarGAlgebra:InvalidOrder', ...
                    'initCyclic: group order must be a positive integer scalar');
            end
            obj.n = n;
            obj.isCyclic = true;
            [I, J] = meshgrid(0:n-1, 0:n-1);
            obj.G = mod(I + J, n) + 1;
            obj.identityIdx = 1;
            obj.F = fft(eye(n));
            obj.Finv = conj(obj.F) / n;
            obj.irrepDims = ones(1, n);
        end

        function obj = initKlein4(obj)
            % initKlein4  Klein four-group Z_2 x Z_2 = {e, a, b, ab} at
            %   indices 1..4; every element is its own inverse (the
            %   diagonal of G is all 1).
            %   F holds the four characters, one row per element. It is
            %   symmetric with F*F = 4*I, hence Finv = F/4. All four irreps
            %   are one-dimensional, recorded in irrepDims like the other
            %   initializers that know their irreps.
            obj.n = 4;
            obj.isCyclic = false;
            obj.G = [1 2 3 4; 2 1 4 3; 3 4 1 2; 4 3 2 1];
            obj.identityIdx = 1;
            obj.F = [1 1 1 1; 1 1 -1 -1; 1 -1 1 -1; 1 -1 -1 1];
            obj.Finv = obj.F / 4;
            obj.irrepDims = [1, 1, 1, 1];
        end

        function obj = initDihedral(obj, n)
            % initDihedral  Dihedral group D_n = <r, s | r^n = s^2 = 1, srs = r^{-1}>
            %   of order 2n. Indices 1..n hold the rotations r^t
            %   (t = 0..n-1) and n+1..2n hold the reflections r^t * s, so
            %   G(i, j) is the index of element(i) * element(j).
            %
            %   F has one row per group element; the row concatenates that
            %   element's value in every irrep (2-by-2 blocks flattened
            %   row-wise), scaled by 1/sqrt(2n):
            %     * trivial and sign (1-d);
            %     * floor((n-1)/2) two-dimensional irreps, k = 1..floor((n-1)/2);
            %     * for even n, two further 1-d irreps.
            %   Mirrors python/large_scale/starg_torch/algebra.py:_build_F_dihedral.
            %
            %   Errors with StarGAlgebra:InvalidOrder unless n is a positive
            %   integer scalar.
            if isempty(n) || ~isscalar(n) || ~isnumeric(n) || ~isreal(n) || n < 1 || n ~= fix(n)
                error('StarGAlgebra:InvalidOrder', ...
                    'initDihedral: n must be a positive integer scalar');
            end
            order = 2*n;
            obj.n = order;
            obj.isCyclic = false;
            obj.G = zeros(order, order);

            for i = 1:order
                for j = 1:order
                    if i <= n && j <= n
                        % rotation * rotation
                        obj.G(i,j) = mod(i-1 + j-1, n) + 1;
                    elseif i <= n && j > n
                        % rotation * reflection
                        obj.G(i,j) = mod(j-n-1 + i-1, n) + n + 1;
                    elseif i > n && j <= n
                        % reflection * rotation
                        obj.G(i,j) = mod(i-n-1 - (j-1), n) + n + 1;
                    else
                        % reflection * reflection
                        obj.G(i,j) = mod((i-n-1) - (j-n-1), n) + 1;
                    end
                end
            end

            obj.identityIdx = 1;

            n_2d = floor((n - 1) / 2);
            irrep_dims_local = [1, 1, repmat(2, 1, n_2d)];
            if mod(n, 2) == 0
                irrep_dims_local = [irrep_dims_local, 1, 1];
            end
            % The squared irrep dimensions sum to the group order, which
            % fixes the number of columns of F; preallocate instead of
            % growing rows one at a time.
            rows = zeros(order, sum(irrep_dims_local .^ 2));
            for g = 1:order
                r = mod(g - 1, n);        % rotation exponent
                s = floor((g - 1) / n);   % 0 = rotation, 1 = reflection
                row = 1;                                          % trivial
                if s == 0, row = [row, 1]; else, row = [row, -1]; end  % sign
                for k = 1:n_2d
                    ang = 2 * pi * k * r / n;
                    if s == 0
                        Mblk = [cos(ang), -sin(ang); sin(ang), cos(ang)];
                    else
                        Mblk = [cos(ang),  sin(ang); sin(ang), -cos(ang)];
                    end
                    row = [row, Mblk(1,1), Mblk(1,2), Mblk(2,1), Mblk(2,2)];
                end
                if mod(n, 2) == 0
                    if s == 0, sgn1 = 1; else, sgn1 = -1; end
                    if mod(r, 2) == 0, sgn2 = 1; else, sgn2 = -1; end
                    row = [row, sgn1 * sgn2, sgn2];
                end
                rows(g, :) = row;
            end
            obj.F = complex(rows / sqrt(order));
            obj.Finv = inv(obj.F);
            obj.irrepDims = irrep_dims_local;
        end

        function obj = initSymmetric(obj, n)
            % initSymmetric  Symmetric group S_n (order n!) from permutations.
            %   S_1 (trivial) and S_2 (= Z_2) are abelian. For them the table
            %   built below, with the identity permutation moved to index 1,
            %   is the standard cyclic table, so the DFT fft(eye(n!)) is a
            %   valid Fourier matrix and they are supported.
            %
            %   For n >= 3 S_n is non-abelian. Its irreps need the
            %   Specht-module / Young-symmetrizer construction, which has not
            %   been ported to MATLAB; the former fft(eye(n!)) fallback was
            %   mathematically wrong there. That case raises
            %   StarGAlgebra:NotImplemented *before* building the n!-by-n!
            %   table. Use the Python pipeline (per-irrep Fourier
            %   projections via discover_at_scale.compute_irrep_fourier_power)
            %   for the published symmetric-group experiments.
            if isempty(n) || ~isscalar(n) || ~isnumeric(n) || ~isreal(n) || n < 1 || n ~= fix(n)
                error('StarGAlgebra:InvalidOrder', ...
                    'initSymmetric: n must be a positive integer scalar');
            end
            if n >= 3
                error('StarGAlgebra:NotImplemented', ...
                    ['initSymmetric: S_%d is non-abelian for n>=3 and ', ...
                     'requires Specht-module irrep construction not yet ', ...
                     'available in MATLAB. Use the Python pipeline at ', ...
                     'python/large_scale/starg_torch for non-abelian work.'], n);
            end

            obj.isCyclic = false;
            perms_list = perms(1:n);
            obj.n = size(perms_list, 1);

            % Move the identity permutation to index 1.
            identity_perm = 1:n;
            identity_idx = find(all(perms_list == identity_perm, 2), 1);
            if identity_idx ~= 1
                perms_list([1, identity_idx], :) = perms_list([identity_idx, 1], :);
            end

            perm_to_idx = containers.Map('KeyType', 'char', 'ValueType', 'int32');
            for i = 1:obj.n
                perm_to_idx(mat2str(perms_list(i,:))) = i;
            end

            obj.G = zeros(obj.n, obj.n);
            for i = 1:obj.n
                for j = 1:obj.n
                    composed = perms_list(i, perms_list(j,:));
                    obj.G(i, j) = perm_to_idx(mat2str(composed));
                end
            end

            obj.identityIdx = 1;
            obj.F = fft(eye(obj.n));
            obj.Finv = conj(obj.F) / obj.n;
            obj.irrepDims = ones(1, obj.n);
        end

        function obj = initQuaternion(obj)
            % initQuaternion  Quaternion group Q_8 of order 8.
            %   Element indices: 1 = 1, 2 = -1, 3 = i, 4 = -i, 5 = j, 6 = -j,
            %   7 = k, 8 = -k; G below is the matching product table.
            %
            %   Irreps: four 1-d characters (Q_8 abelianises to Z_2 x Z_2)
            %   and one faithful 2-d representation rho with
            %     rho(i) = [1i 0; 0 -1i],  rho(j) = [0 1; -1 0],
            %     rho(k) = rho(i) * rho(j),  rho(-x) = -rho(x).
            %   Row g of F lists, for element g, the four character values
            %   followed by rho(g) flattened row-wise, all divided by
            %   sqrt(8). That scaling gives the 1-d columns unit norm; the
            %   2-d columns have norm 1/sqrt(2), so F is invertible but not
            %   unitary, and Finv is computed with inv.
            obj.n = 8;
            obj.isCyclic = false;
            obj.G = [1 2 3 4 5 6 7 8;
                2 1 4 3 6 5 8 7;
                3 4 2 1 7 8 6 5;
                4 3 1 2 8 7 5 6;
                5 6 8 7 2 1 3 4;
                6 5 7 8 1 2 4 3;
                7 8 5 6 4 3 2 1;
                8 7 6 5 3 4 1 2];
            obj.identityIdx = 1;

            % Character table, one row per element; the columns are
            %   chi_triv (all +1),
            %   chi_a (i,-i -> -1; j,-j -> +1; k,-k -> -1),
            %   chi_b (i,-i -> +1; j,-j -> -1; k,-k -> -1),
            %   chi_c (i,-i -> -1; j,-j -> -1; k,-k -> +1).
            chars = [ones(1, 8);
                     1 1 -1 -1  1  1 -1 -1;
                     1 1  1  1 -1 -1 -1 -1;
                     1 1 -1 -1 -1 -1  1  1].';

            rho_i = [1i, 0; 0, -1i];
            rho_j = [0, 1; -1, 0];
            rho_k = rho_i * rho_j;
            % rho of the odd-indexed elements 1, i, j, k; even indices are negatives.
            positives = {eye(2), rho_i, rho_j, rho_k};

            rows = zeros(8, 8);
            for g = 1:8
                M = positives{ceil(g / 2)};
                if mod(g, 2) == 0
                    M = -M;
                end
                rows(g, :) = [chars(g, :), M(1, 1), M(1, 2), M(2, 1), M(2, 2)];
            end
            obj.F = complex(rows) / sqrt(8);
            obj.Finv = inv(obj.F);
            obj.irrepDims = [1, 1, 1, 1, 2];
        end

        function obj = initOctahedral(obj)
            % Chiral octahedral group O of order 24. Five irreps: A_1
            % (trivial, 1-d), A_2 (sign of permutation on 4 body
            % diagonals, 1-d), E (2-d, l=2 doublet on Sym^2(R^3)/trace),
            % T_1 (3-d, the rotation matrices themselves, l=1), T_2 (3-d,
            % traceless symmetric tensor part, l=2). Mirrors
            % python/large_scale/starg_torch/octahedral.octahedral_irreps.
            %
            % Sets n = 24; G (24x24, built by matching every product
            % R{i}*R{j} against the list of rotations); identityIdx = 1
            % (R{1} is eye(3)); F with one row per element holding A_1,
            % A_2, then E, T_1 and T_2 flattened row-wise
            % (1 + 1 + 4 + 9 + 9 = 24 columns) divided by sqrt(24);
            % Finv = inv(F); irrepDims = [1 1 2 3 3].
            obj.isCyclic = false;
            obj.n = 24;
            R = obj.octahedralRotations();

            % Build multiplication table.
            % Products are matched with a 1e-6 tolerance because R{k}
            % comes from floating-point Rodrigues rotations.
            T = zeros(24, 24);
            for i = 1:24
                for j = 1:24
                    P = R{i} * R{j};
                    found = -1;
                    for k = 1:24
                        if max(abs(P(:) - R{k}(:))) < 1e-6
                            found = k; break;
                        end
                    end
                    if found < 0
                        error('octahedral product not in group');
                    end
                    T(i, j) = found;
                end
            end
            obj.G = T;
            obj.identityIdx = 1;

            % A_1 (trivial)
            A1 = ones(24, 1);
            % A_2: sign of permutation induced on the 4 body diagonals.
            % A rotated diagonal may come out reversed, so both +d and -d
            % count as a match for diagonal d.
            diags = [ 1  1  1; 1  1 -1; 1 -1  1; -1  1  1];
            A2 = zeros(24, 1);
            for g = 1:24
                permuted = (R{g} * diags')';
                idx = zeros(4, 1);
                for ii = 1:4
                    for jj = 1:4
                        if max(abs(permuted(ii, :) - diags(jj, :))) < 1e-6 || ...
                           max(abs(permuted(ii, :) + diags(jj, :))) < 1e-6
                            idx(ii) = jj; break;
                        end
                    end
                end
                A2(g) = obj.permSign(idx);
            end

            % T_1: the rotation matrices (3-d)
            % E and T_2: from the symmetric traceless rank-2 representation.
            % 5-d basis: (xx-yy)/sqrt(2), (2zz-xx-yy)/sqrt(6), xy, xz, yz.
            % Rows of B are coordinates in the 6-d basis used by
            % sym2RepFull (xx, yy, zz, xy, xz, yz), normalised to unit length.
            B = [1, -1, 0, 0, 0, 0;
                 -1, -1, 2, 0, 0, 0;
                 0, 0, 0, 1, 0, 0;
                 0, 0, 0, 0, 1, 0;
                 0, 0, 0, 0, 0, 1];
            for r = 1:5
                B(r, :) = B(r, :) / norm(B(r, :));
            end

            % NOTE(review): B * Mfull * B' is a valid restriction only if
            % the span of B's rows is invariant under Mfull in this
            % (non-orthonormal for xy/xz/yz) basis; the E / T_2 split takes
            % the diagonal 2x2 and 3x3 blocks. Confirm against the Python
            % octahedral_irreps before relying on these blocks.
            E_mats = cell(24, 1);
            T2_mats = cell(24, 1);
            for g = 1:24
                Mfull = obj.sym2RepFull(R{g});      % (6,6)
                M5 = B * Mfull * B';                % (5,5) traceless
                E_mats{g} = M5(1:2, 1:2);
                T2_mats{g} = M5(3:5, 3:5);
            end

            % One row per element: A_1, A_2, E (4), T_1 (9), T_2 (9).
            rows = zeros(24, 24);
            for g = 1:24
                row = [A1(g), A2(g)];
                row = [row, E_mats{g}(1,1), E_mats{g}(1,2), E_mats{g}(2,1), E_mats{g}(2,2)];
                T1g = R{g};
                row = [row, T1g(1,1), T1g(1,2), T1g(1,3), ...
                            T1g(2,1), T1g(2,2), T1g(2,3), ...
                            T1g(3,1), T1g(3,2), T1g(3,3)];
                T2g = T2_mats{g};
                row = [row, T2g(1,1), T2g(1,2), T2g(1,3), ...
                            T2g(2,1), T2g(2,2), T2g(2,3), ...
                            T2g(3,1), T2g(3,2), T2g(3,3)];
                rows(g, :) = row;
            end
            obj.F = complex(rows) / sqrt(24);
            obj.Finv = inv(obj.F);
            obj.irrepDims = [1, 1, 2, 3, 3];
        end

        function R = octahedralRotations(~)
            % octahedralRotations  The 24 proper rotations of the cube.
            %   R is a 24-by-1 cell of 3-by-3 matrices, ordered as:
            %     identity;
            %     +90, 180, -90 degrees about the x, y, z face axes (9);
            %     +-120 degrees about the four body diagonals (8);
            %     180 degrees about the six edge axes (6).
            %   Each matrix comes from the Rodrigues formula and is replaced
            %   by its rounded value when every entry is within 1e-6 of an
            %   integer.
            skew = @(u) [0, -u(3), u(2); u(3), 0, -u(1); -u(2), u(1), 0];
            rodrigues = @(Kmat, angle) eye(3) + sin(angle)*Kmat + (1-cos(angle))*(Kmat*Kmat);
            rotate = @(axis, angle) rodrigues(skew(axis(:) / norm(axis)), angle);

            R = cell(24, 1);
            R{1} = eye(3);
            next = 2;
            for u = {[1;0;0], [0;1;0], [0;0;1]}
                for angle = [pi/2, pi, -pi/2]
                    R{next} = rotate(u{1}, angle);
                    next = next + 1;
                end
            end
            for u = {[1;1;1], [1;1;-1], [1;-1;1], [-1;1;1]}
                for angle = [2*pi/3, -2*pi/3]
                    R{next} = rotate(u{1}, angle);
                    next = next + 1;
                end
            end
            for u = {[1;1;0], [1;-1;0], [1;0;1], [1;0;-1], [0;1;1], [0;1;-1]}
                R{next} = rotate(u{1}, pi);
                next = next + 1;
            end

            % Snap numerically integral matrices to exact integers.
            for t = 1:numel(R)
                snapped = round(R{t});
                if max(abs(R{t}(:) - snapped(:))) < 1e-6
                    R{t} = snapped;
                end
            end
        end

        function s = permSign(~, perm)
            % permSign  Sign of a permutation vector: +1 if even, -1 if odd.
            %   Counts inversions, i.e. pairs i < j with perm(i) > perm(j).
            %   (Avoids the former local variable named inv, which shadowed
            %   the builtin.)
            p = perm(:);
            nInversions = nnz(triu(p > p.', 1));
            if mod(nInversions, 2) == 0
                s = 1;
            else
                s = -1;
            end
        end

        function M = sym2RepFull(~, R)
            % sym2RepFull  6-by-6 action of R on symmetric rank-2 tensors in
            %   the basis order (xx, yy, zz, xy, xz, yz).
            %   With (a, b) the index pair of basis element i and (c, d) that
            %   of element j:
            %     M(i, j) = R(a,c)*R(b,d)                  if c == d,
            %     M(i, j) = R(a,c)*R(b,d) + R(a,d)*R(b,c)  otherwise.
            first = [1; 2; 3; 1; 1; 2];
            second = [1; 2; 3; 2; 3; 3];
            M = R(first, first) .* R(second, second);
            offDiag = first ~= second;   % columns for xy, xz, yz
            M(:, offDiag) = M(:, offDiag) + R(first, second(offDiag)) .* R(second, first(offDiag));
        end

        function obj = initProduct(obj, groups)
            % initProduct  Direct product G_1 x ... x G_k of StarGAlgebra objects.
            %   groups : non-empty cell array of StarGAlgebra objects (the
            %            factors, in order).
            %   Element (i_1, ..., i_k) is stored at the linear index of the
            %   multi-index with the last factor varying fastest. This is
            %   the ordering kron uses, so F = kron(F_1, ..., F_k) and
            %   Finv = kron(Finv_1, ..., Finv_k).
            %
            %   Errors with StarGAlgebra:InvalidFactors for a non-cell or
            %   empty input (previously an obscure indexing error).
            if ~iscell(groups) || isempty(groups)
                error('StarGAlgebra:InvalidFactors', ...
                    'initProduct: expected a non-empty cell array of StarGAlgebra objects');
            end
            k = numel(groups);
            orders = cellfun(@(g) g.n, groups);
            orders = orders(:).';
            obj.n = prod(orders);
            obj.isCyclic = false;

            % Multi-index of every element, computed once instead of inside
            % the O(n^2) table loop (row a holds the multi-index of element a).
            multi = zeros(obj.n, k);
            for a = 1:obj.n
                multi(a, :) = obj.linearToMultiIndex(a, orders);
            end

            obj.G = zeros(obj.n, obj.n);
            c_idx = zeros(1, k);
            for a = 1:obj.n
                for b = 1:obj.n
                    % Componentwise product in each factor.
                    for i = 1:k
                        c_idx(i) = groups{i}.G(multi(a, i), multi(b, i));
                    end
                    obj.G(a, b) = obj.multiToLinearIndex(c_idx, orders);
                end
            end

            obj.identityIdx = 1;
            obj.F = groups{1}.F;
            obj.Finv = groups{1}.Finv;
            for i = 2:k
                obj.F = kron(obj.F, groups{i}.F);
                obj.Finv = kron(obj.Finv, groups{i}.Finv);
            end
        end

        function idx = linearToMultiIndex(~, lin, orders)
            % linearToMultiIndex  1-based linear index -> 1-based multi-index.
            %   The last position varies fastest, matching kron ordering;
            %   this is the inverse of multiToLinearIndex. Returns a
            %   1-by-numel(orders) row.
            dims = orders(:).';
            strides = cumprod([dims(2:end), 1], 'reverse');   % product of later dims
            idx = mod(floor((lin - 1) ./ strides), dims) + 1;
        end

        function lin = multiToLinearIndex(~, idx, orders)
            % multiToLinearIndex  1-based multi-index -> 1-based linear index.
            %   The last position varies fastest, matching kron ordering;
            %   this is the inverse of linearToMultiIndex.
            dims = orders(:).';
            offsets = idx(1:numel(dims)) - 1;
            strides = cumprod([dims(2:end), 1], 'reverse');   % product of later dims
            lin = sum(offsets(:).' .* strides) + 1;
        end

        function obj = initFromTable(obj, multTable)
            % initFromTable  Group from a user-supplied multiplication table.
            %   multTable : n-by-n numeric array, multTable(a, b) being the
            %               index (1..n) of the product a*b.
            %
            %   Fourier matrix:
            %     * If the table is exactly the standard cyclic table
            %       mod((a-1) + (b-1), n) + 1, F = fft(eye(n)) is correct and
            %       isCyclic is set, so the FFT fast paths are used.
            %     * Otherwise F = fft(eye(n)) is only a placeholder (it is the
            %       Fourier matrix of Z_n in standard order) and a warning is
            %       issued. It is wrong for SVD / Fourier-domain operations on
            %       any other abelian table (e.g. Klein-4 or a relabelled
            %       Z_n). It is also wrong for non-abelian tables, which need
            %       their irrep matrices: use StarGAlgebra('octahedral'),
            %       StarGAlgebra('dihedral', n), StarGAlgebra('quaternion')
            %       or the Python pipeline at python/large_scale/starg_torch.
            %
            %   Errors with StarGAlgebra:InvalidTable for a table that is not
            %   a non-empty square numeric matrix of integers in 1..n.
            if isempty(multTable) || ~isnumeric(multTable) || ~ismatrix(multTable) || ...
                    size(multTable, 1) ~= size(multTable, 2)
                error('StarGAlgebra:InvalidTable', ...
                    'initFromTable: multiplication table must be a non-empty square numeric matrix');
            end
            ng = size(multTable, 1);
            entries = multTable(:);
            if any(entries < 1 | entries > ng | entries ~= fix(entries))
                error('StarGAlgebra:InvalidTable', ...
                    'initFromTable: table entries must be integers in 1..%d', ng);
            end

            obj.n = ng;
            obj.G = multTable;
            [I, J] = meshgrid(0:ng-1, 0:ng-1);
            obj.isCyclic = isequal(multTable, mod(I + J, ng) + 1);

            if ~obj.isCyclic
                isAb = isequal(multTable, multTable.');
                if isAb
                    warning('StarGAlgebra:AbelianTableOrder', ...
                        ['initFromTable: the supplied table is abelian but not ', ...
                         'the standard cyclic table; F = fft(eye(n)) is only a ', ...
                         'placeholder (the Z_n Fourier matrix) and is incorrect ', ...
                         'for SVD/Fourier-domain operations on this table.']);
                else
                    warning('StarGAlgebra:NonAbelianTable', ...
                        ['initFromTable: the supplied multiplication table ', ...
                         'is non-abelian; a correct generalized Fourier ', ...
                         'matrix requires the irrep matrices and is not ', ...
                         'derivable from the table alone in this MATLAB ', ...
                         'implementation. Falling back to the DFT matrix ', ...
                         'fft(eye(n)) as F (incorrect for SVD/Fourier-domain ', ...
                         'operations). ', ...
                         'Prefer StarGAlgebra(''octahedral''), ', ...
                         'StarGAlgebra(''dihedral'', n), or ', ...
                         'StarGAlgebra(''quaternion'').']);
                end
            end
            obj.F = fft(eye(ng));
            obj.Finv = conj(obj.F) / ng;
        end

        %% Utility
        function I = identity(obj, m)
            % identity  Unit element of the algebra for m-by-m slices.
            %
            %   Returns an m-by-m-by-n tensor that is eye(m) in the slice of
            %   the group identity (obj.identityIdx) and zero elsewhere.
            slab = eye(m);
            I = zeros([size(slab), obj.n]);
            I(:, :, obj.identityIdx) = slab;
        end

        function nrm = tensorNorm(~, A)
            % tensorNorm  Frobenius norm of an array of any shape.
            %
            %   Mathematically equal to sqrt(sum(abs(A(:)).^2)). norm() is used
            %   because it scales internally, so entries larger than ~1e154 do
            %   not overflow to Inf when squared. A is expected to be single or
            %   double, as norm() requires floating-point input.
            nrm = norm(A(:));
        end

        %% Benchmark
        function benchmark(obj, sizes)
            % benchmark  Time starG_direct against starG on random square inputs.
            %
            %   For each entry sz of SIZES (default [10 20 50 100]):
            %   - draw A and B of size sz-by-sz-by-n;
            %   - call starG once as a warm-up;
            %   - report the mean time of 3 starG_direct calls (only when
            %     sz <= 30) and of 10 starG calls, plus their ratio.
            if nargin < 2
                sizes = [10, 20, 50, 100];
            end
            directReps = 3;
            fastReps = 10;
            directMaxSize = 30;

            fprintf('\n=== Performance Benchmark ===\n');
            fprintf('Group: n=%d, Cyclic=%d, GPU=%d\n', obj.n, obj.isCyclic, obj.useGPU);
            fprintf('%-10s %-15s %-15s %-10s\n', 'Size', 'Direct (s)', 'Optimized (s)', 'Speedup');
            fprintf('%s\n', repmat('-', 1, 55));

            for sz = sizes
                A = randn(sz, sz, obj.n);
                B = randn(sz, sz, obj.n);

                obj.starG(A, B);  % warm-up; result intentionally discarded

                % The direct method is too slow to time on large inputs.
                t_direct = NaN;
                if sz <= directMaxSize
                    tic;
                    for r = 1:directReps
                        obj.starG_direct(A, B);
                    end
                    t_direct = toc / directReps;
                end

                tic;
                for r = 1:fastReps
                    obj.starG(A, B);
                end
                t_opt = toc / fastReps;

                if isnan(t_direct)
                    fprintf('%-10d %-15s %-15.4f %-10s\n', sz, 'N/A', t_opt, '-');
                else
                    fprintf('%-10d %-15.4f %-15.4f %-10.1fx\n', sz, t_direct, t_opt, t_direct/t_opt);
                end
            end
        end

        %% Verification Suite
        function runAllTests(obj)
            % runAllTests  Run the whole verification suite between banner lines.
            %
            %   Prints the group configuration, then invokes the eight test
            %   methods in a fixed order. Each method prints its own
            %   PASS/FAIL/SKIP lines.
            rule = '========================================';
            fprintf('%s\n', rule);
            fprintf('StarGAlgebra Verification Suite\n');
            fprintf('Group order: %d, Cyclic: %d, Abelian: %d\n', ...
                obj.n, obj.isCyclic, obj.isAbelian);
            fprintf('Identity at index: %d, GPU: %d\n', obj.identityIdx, obj.useGPU);
            fprintf('%s\n\n', rule);

            suite = {'testGroupAxioms', 'testConvolutionTensor', ...
                     'testConvolutionMethods', 'testStarGMethods', ...
                     'testConjugateTranspose', 'testIdentity', ...
                     'testAssociativity', 'testSVD'};
            for t = 1:numel(suite)
                feval(suite{t}, obj);
            end

            fprintf('\n%s\n', rule);
            fprintf('All tests completed.\n');
            fprintf('%s\n', rule);
        end

        function testGroupAxioms(obj, maxExhaustive)
            % testGroupAxioms  Check identity, inverse and associativity of obj.G.
            %
            %   Prints one line per detected failure, and '  PASS' when every
            %   check holds. Associativity is verified over ALL n^3 triples
            %   whenever n^3 <= MAXEXHAUSTIVE (optional, default 1e6). Random
            %   sampling with replacement can miss a non-associative triple even
            %   for tiny tables. Beyond that size, 500 random triples are
            %   sampled to bound the cost.
            if nargin < 2
                maxExhaustive = 1e6;
            end

            fprintf('Test 1: Group Axioms\n');
            passed = true;
            n = obj.n;
            e = obj.identityIdx;

            for a = 1:n
                if obj.G(e, a) ~= a || obj.G(a, e) ~= a
                    fprintf('  FAIL: Identity for %d\n', a);
                    passed = false;
                end
            end

            for a = 1:n
                a_inv = obj.invTable(a);
                if obj.G(a, a_inv) ~= e || obj.G(a_inv, a) ~= e
                    fprintf('  FAIL: Inverse for %d\n', a);
                    passed = false;
                end
            end

            if n^3 <= maxExhaustive
                % lhs(a,b,c) = G(G(a,b), c): rows of Gt(Gt(:),:) run over (a,b).
                % rhs(a,b,c) = G(a, G(b,c)): columns of Gt(:,Gt(:)) run over (b,c).
                Gt = obj.G;
                lhs = reshape(Gt(Gt(:), :), n, n, n);
                rhs = reshape(Gt(:, Gt(:)), n, n, n);
                assocOK = isequal(lhs, rhs);
            else
                assocOK = true;
                for t = 1:500
                    a = randi(n); b = randi(n); c = randi(n);
                    if obj.G(obj.G(a,b), c) ~= obj.G(a, obj.G(b,c))
                        assocOK = false;
                        break;
                    end
                end
            end
            if ~assocOK
                fprintf('  FAIL: Associativity\n');
                passed = false;
            end

            if passed
                fprintf('  PASS\n');
            end
        end

        function testConvolutionTensor(obj)
            % testConvolutionTensor  Check convTensor is the indicator of G.
            %
            %   For every pair (a, b), the fiber convTensor(a, b, :) must be
            %   1 at c = G(a, b) and 0 everywhere else. Prints PASS or FAIL.
            fprintf('\nTest 2: Convolution Tensor\n');
            allMatch = true;
            slots = (1:obj.n).';

            for a = 1:obj.n
                for b = 1:obj.n
                    fiber = reshape(obj.convTensor(a, b, 1:obj.n), [], 1);
                    if any(fiber ~= (slots == obj.G(a, b)))
                        allMatch = false;
                    end
                end
            end

            if ~allMatch
                fprintf('  FAIL\n');
            else
                fprintf('  PASS\n');
            end
        end

        function testConvolutionMethods(obj)
            % testConvolutionMethods  Cross-check the three 1-D convolution routines.
            %
            %   Convolves two random length-n vectors with convolve_direct
            %   (reference), convolve_inverse and convolve. Prints both relative
            %   deviations from the reference, and PASS when both are below 1e-6.
            fprintf('\nTest 3: 1D Convolution\n');

            x = randn(obj.n, 1);
            y = randn(obj.n, 1);

            ref = obj.convolve_direct(x, y);
            viaInverse = obj.convolve_inverse(x, y);
            viaFast = obj.convolve(x, y);

            scale = norm(ref);
            relInverse = norm(ref - viaInverse) / scale;
            relFast = norm(ref - viaFast) / scale;

            fprintf('  Direct vs Inverse: %.2e\n', relInverse);
            fprintf('  Direct vs Optimized: %.2e\n', relFast);

            tol = 1e-6;
            verdict = 'FAIL';
            if relInverse < tol && relFast < tol
                verdict = 'PASS';
            end
            fprintf('  %s\n', verdict);
        end

        function testStarGMethods(obj)
            % testStarGMethods  Compare starG with the reference starG_direct.
            %
            %   Multiplies random 3x4xn and 4x2xn tensors with both
            %   implementations. Prints their relative difference and wall-clock
            %   times, and PASS when the difference is below 1e-6.
            fprintf('\nTest 4: StarG Product\n');

            X = randn(3, 4, obj.n);
            Y = randn(4, 2, obj.n);

            tic;
            ref = obj.starG_direct(X, Y);
            tRef = toc;

            tic;
            fast = obj.starG(X, Y);
            tFast = toc;

            relErr = norm(ref(:) - fast(:)) / norm(ref(:));
            fprintf('  Direct vs Main: error=%.2e (%.4fs vs %.4fs)\n', relErr, tRef, tFast);

            verdict = 'FAIL';
            if relErr < 1e-6
                verdict = 'PASS';
            end
            fprintf('  %s\n', verdict);
        end

        function testConjugateTranspose(obj)
            % testConjugateTranspose  Compare conjugateTranspose with its fast variant.
            %
            %   Applies both routines to a random complex 3x4xn tensor. Prints
            %   PASS/FAIL with the relative difference; the threshold is 1e-10.
            fprintf('\nTest 5: Conjugate Transpose\n');

            Z = randn(3, 4, obj.n) + 1i*randn(3, 4, obj.n);

            slow = obj.conjugateTranspose(Z);
            fast = obj.conjugateTranspose_fast(Z);
            relErr = norm(slow(:) - fast(:)) / norm(slow(:));

            if relErr < 1e-10
                status = 'PASS';
            else
                status = 'FAIL';
            end
            fprintf('  %s (error=%.2e)\n', status, relErr);
        end

        function testIdentity(obj)
            % testIdentity  Check obj.identity(m) is a two-sided unit for starG.
            %
            %   Multiplies a random 4x4xn tensor by the identity on both sides.
            %   Prints both relative errors, and PASS when both are below 1e-10.
            fprintf('\nTest 6: Identity\n');

            m = 4;
            E = obj.identity(m);
            X = randn(m, m, obj.n);

            left = obj.starG(E, X);
            right = obj.starG(X, E);

            scale = norm(X(:));
            errLeft = norm(left(:) - X(:)) / scale;
            errRight = norm(right(:) - X(:)) / scale;

            fprintf('  ||I*A - A||/||A|| = %.2e\n', errLeft);
            fprintf('  ||A*I - A||/||A|| = %.2e\n', errRight);

            tol = 1e-10;
            verdict = 'FAIL';
            if errLeft < tol && errRight < tol
                verdict = 'PASS';
            end
            fprintf('  %s\n', verdict);
        end

        function testAssociativity(obj)
            % testAssociativity  Check (X*Y)*Z == X*(Y*Z) under starG.
            %
            %   Uses random 2x3xn, 3x2xn and 2x2xn tensors. Prints PASS/FAIL
            %   with the relative difference; the threshold is 1e-10.
            fprintf('\nTest 7: Associativity\n');

            X = randn(2, 3, obj.n);
            Y = randn(3, 2, obj.n);
            Z = randn(2, 2, obj.n);

            leftAssoc = obj.starG(obj.starG(X, Y), Z);
            rightAssoc = obj.starG(X, obj.starG(Y, Z));

            relErr = norm(leftAssoc(:) - rightAssoc(:)) / norm(leftAssoc(:));

            if relErr < 1e-10
                status = 'PASS';
            else
                status = 'FAIL';
            end
            fprintf('  %s (error=%.2e)\n', status, relErr);
        end

        function testSVD(obj)
            % testSVD  Check that starG_SVD factors reconstruct the input.
            %
            %   Only runs for cyclic groups; otherwise prints SKIP. Rebuilds a
            %   random 4x3xn tensor as U * S * V^H under starG and prints
            %   PASS/FAIL with the relative error (threshold 1e-10).
            fprintf('\nTest 8: SVD\n');

            if ~obj.isCyclic
                fprintf('  SKIP (non-cyclic)\n');
                return;
            end

            X = randn(4, 3, obj.n);
            [U, S, V] = obj.starG_SVD(X);
            VH = obj.conjugateTranspose(V);
            US = obj.starG(U, S);
            rebuilt = obj.starG(US, VH);

            relErr = norm(X(:) - rebuilt(:)) / norm(X(:));
            if relErr < 1e-10
                status = 'PASS';
            else
                status = 'FAIL';
            end
            fprintf('  %s (error=%.2e)\n', status, relErr);
        end
    end
end

%% ========================================================================
%% ADDITIONAL METHODS FOR StarGAlgebra CLASS
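%% Paste these functions into the methods block of StarGAlgebra.m. Left here,
%% after the classdef, they are file-local functions rather than methods, so a
%% call such as obj.starG_SVD_stable(A) will not resolve to them.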
%% ========================================================================

function [U, S, V, sv_invariant] = starG_SVD_stable(obj, A)
%% starG_SVD_stable - ★_G-SVD together with a set of invariant summaries
%%
%% Cyclic groups (obj.isCyclic):
%%   - A is transformed with fft along the third (group) dimension.
%%   - Each Fourier slice is factorised with svd(..., 'econ').
%%   - The factors are mapped back with ifft.
%% For real A, only the floor(n/2)+1 non-redundant slices are factorised.
%% The remaining slices are filled with the complex conjugates of their
%% mirror slices, so U, S, V are real up to round-off and still reproduce A.
%% Factorising every slice independently does not guarantee this: svd of a
%% conjugated matrix need not return the conjugated factors, and the final
%% real() would then discard genuine imaginary parts.
%%
%% Non-cyclic groups: each slice A(:,:,g) is factorised on its own. NOTE:
%% this is a slice-wise SVD, not a ★_G factorisation, so in general it does
%% not satisfy A = U ★ S ★ V^H.
%%
%% INPUTS:
%%   obj - algebra object; only obj.isCyclic is read here
%%   A   - l-by-m-by-n array (n = group order)
%%
%% OUTPUTS:
%%   U, S, V      - factors of sizes l x k x n, k x k x n, m x k x n, with
%%                  k = min(l, m)
%%   sv_invariant - struct with fields:
%%     .magnitudes - singular values per slice, each column sorted
%%                   descending. Cyclic: one-sided spectrum, floor(n/2)+1
%%                   columns. Non-cyclic: n columns.
%%     .power_spec - abs(fft(A, [], 3)).^2 (cyclic DFT along the group axis)
%%     .trace_invs - 4x1 vector:
%%                     1. sum |A|^2
%%                     2. sum_g trace(Ag*Ag')
%%                     3. sum_g trace((Ag*Ag')^2)
%%                     4. nuclear norm of the (l*m)-by-n unfolding of A

[l, m, n] = size(A);
minlm = min(l, m);
realInput = isreal(A);

sv_invariant = struct();

if obj.isCyclic
    % === Fourier domain computation ===
    Ahat = fft(A, [], 3);

    n_freq = floor(n/2) + 1;  % size of the one-sided spectrum

    Uhat = zeros(l, minlm, n);
    Shat = zeros(minlm, minlm, n);
    Vhat = zeros(m, minlm, n);

    sv_mags = zeros(minlm, n_freq);

    % For real A, slice k and slice n-k+2 are complex conjugates of each
    % other, so only the one-sided spectrum needs an SVD.
    if realInput
        nSolve = n_freq;
    else
        nSolve = n;
    end

    for k = 1:nSolve
        slice = Ahat(:,:,k);
        if realInput && (k == 1 || 2*(k-1) == n)
            % DC and Nyquist slices of a real signal are real: drop the
            % zero imaginary part so svd returns real factors.
            slice = real(slice);
        end
        [Uk, Sk, Vk] = svd(slice, 'econ');

        kk = size(Uk, 2);
        Uhat(:, 1:kk, k) = Uk;
        Shat(1:kk, 1:kk, k) = Sk;
        Vhat(:, 1:kk, k) = Vk;

        % Store magnitudes for the one-sided spectrum
        if k <= n_freq
            sv_mags(1:kk, k) = abs(diag(Sk));
        end
    end

    if realInput
        % Enforce conjugate symmetry so the inverse FFT is real.
        for k = n_freq+1:n
            mirror = n - k + 2;
            Uhat(:,:,k) = conj(Uhat(:,:,mirror));
            Shat(:,:,k) = Shat(:,:,mirror);
            Vhat(:,:,k) = conj(Vhat(:,:,mirror));
        end
    end

    U = ifft(Uhat, [], 3);
    S = ifft(Shat, [], 3);
    V = ifft(Vhat, [], 3);

    sv_invariant.magnitudes = sort(sv_mags, 1, 'descend');

else
    % === Slice-wise computation for non-cyclic groups (see NOTE above) ===
    U = zeros(l, minlm, n);
    S = zeros(minlm, minlm, n);
    V = zeros(m, minlm, n);

    sv_mags = zeros(minlm, n);

    for g = 1:n
        [Ug, Sg, Vg] = svd(A(:,:,g), 'econ');
        kk = size(Ug, 2);
        U(:, 1:kk, g) = Ug;
        S(1:kk, 1:kk, g) = Sg;
        V(:, 1:kk, g) = Vg;
        sv_mags(1:kk, g) = diag(Sg);
    end

    sv_invariant.magnitudes = sort(sv_mags, 1, 'descend');
end

% === Additional invariants ===

% Power spectrum of the cyclic DFT along the group dimension
sv_invariant.power_spec = abs(fft(A, [], 3)).^2;

% Trace invariants
sv_invariant.trace_invs = zeros(4, 1);
A_flat = reshape(A, [l*m, n]);
% Squared Frobenius norm; abs() keeps this correct for complex A.
sv_invariant.trace_invs(1) = sum(abs(A_flat(:)).^2);

for g = 1:n
    Ag = A(:,:,g);
    sv_invariant.trace_invs(2) = sv_invariant.trace_invs(2) + trace(Ag * Ag');
    sv_invariant.trace_invs(3) = sv_invariant.trace_invs(3) + trace((Ag * Ag')^2);
end
sv_invariant.trace_invs(4) = sum(svd(A_flat));  % nuclear norm of the unfolding

% For real input the factors are real up to round-off; drop the residue.
if realInput
    U = real(U);
    S = real(S);
    V = real(V);
end
end


function feat = extractInvariantFeatures_stable(obj, A)
%% extractInvariantFeatures_stable - Build a feature row vector from invariants
%%
%% Concatenates, in order:
%%   1. sv_invariant.magnitudes from obj.starG_SVD_stable(A), flattened
%%   2. sum, max, mean and std of sv_invariant.power_spec
%%   3. sv_invariant.trace_invs (transposed)
%%   4. the l eigenvalues of Ag*Ag' (sorted descending), averaged over the
%%      n slices g
%% The result is rounded to 12 decimal places.
%%
%% The Gram eigenvalues in (4) are computed as squared singular values of Ag,
%% padded with zeros when l > m. Explicitly forming Ag*Ag' squares the
%% condition number, and eig on it can return slightly negative or complex
%% values for rank-deficient slices.

[l, ~, n] = size(A);

% Get stable SVD with invariants
[~, ~, ~, sv_inv] = obj.starG_SVD_stable(A);

feat = [];

% 1. Singular value magnitudes (flattened, sorted)
feat = [feat, sv_inv.magnitudes(:)'];

% 2. Power spectrum statistics
ps = sv_inv.power_spec(:);
feat = [feat, sum(ps), max(ps), mean(ps), std(ps)];

% 3. Trace invariants
feat = [feat, sv_inv.trace_invs'];

% 4. Gram matrix eigenvalues for each slice (averaged)
gram_sum = zeros(1, l);
for g = 1:n
    s2 = svd(A(:,:,g)).^2;              % descending, min(l, m) values
    s2 = [s2; zeros(l - numel(s2), 1)]; % remaining Gram eigenvalues are 0
    gram_sum = gram_sum + s2.';
end
feat = [feat, gram_sum / n];

% Round to suppress round-off differences between equivalent inputs
feat = round(feat, 12);
end