%% ============================================================================
%% GPU-ACCELERATED StarGAlgebra CLASS
%% ============================================================================
% LH & SU 2026
%% ============================================================================
classdef StarGAlgebra
properties
% Multiplication table: G(a, b) is the index of the product of elements a and b.
G
% Group order (number of elements / size of mode 3 of the tensors).
n
% Fourier matrix set by the init* methods (fft(eye(n)) for cyclic groups).
F
% Inverse of F as set by the init* methods.
Finv
% Dimensions of the irreducible representations; only set by the dihedral,
% quaternion and octahedral initializers.
irrepDims
% n x n x n 0/1 tensor with convTensor(a, b, G(a, b)) = 1.
convTensor
% invTable(a) is the index b with G(a, b) equal to the identity.
invTable
% True when a parallel pool already exists at construction time.
useParallel
% True after enableGPU succeeds; selects the gpuArray code paths.
useGPU
% True when G equals its transpose.
isAbelian
% Set by the init* methods; selects the FFT-based code paths.
isCyclic
% Index of the identity element in G.
identityIdx
% GPU arrays
G_gpu
convTensor_gpu
invTable_gpu
F_gpu
Finv_gpu
end
methods
function obj = StarGAlgebra(groupType, varargin)
    % StarGAlgebra  Construct the algebra of a named finite group.
    %
    %   obj = StarGAlgebra(groupType)
    %   obj = StarGAlgebra(groupType, groupParam)
    %
    %   groupType  - 'cyclic', 'dihedral', 'symmetric', 'klein4',
    %                'octahedral', 'quaternion', 'product' or 'table'
    %                (case-insensitive char vector or string scalar).
    %   groupParam - forwarded to the matching init* method (order n,
    %                cell array of factor groups, or multiplication table).
    %
    %   Called with no arguments, returns an uninitialized object so that
    %   object arrays can be preallocated.
    obj.useGPU = false;

    % gcp belongs to Parallel Computing Toolbox; without it the class must
    % still be usable on the CPU, so a failure just means "no pool".
    try
        obj.useParallel = ~isempty(gcp('nocreate'));
    catch
        obj.useParallel = false;
    end

    if nargin == 0
        return;
    end

    if isstring(groupType) && isscalar(groupType)
        groupType = char(groupType);
    end
    if ~ischar(groupType)
        error('StarGAlgebra:InvalidGroupType', ...
            'groupType must be a character vector or string scalar.');
    end

    groupParam = [];
    if ~isempty(varargin)
        groupParam = varargin{1};
    end

    switch lower(groupType)
        case 'cyclic'
            obj = obj.initCyclic(groupParam);
        case 'dihedral'
            obj = obj.initDihedral(groupParam);
        case 'symmetric'
            obj = obj.initSymmetric(groupParam);
        case 'klein4'
            obj = obj.initKlein4();
        case 'octahedral'
            obj = obj.initOctahedral();
        case 'quaternion'
            obj = obj.initQuaternion();
        case 'product'
            obj = obj.initProduct(groupParam);
        case 'table'
            obj = obj.initFromTable(groupParam);
        otherwise
            error('Unknown group type: %s', groupType);
    end

    % Derived structures; findIdentity runs last and overrides any
    % identityIdx assumed by the initializer.
    obj = obj.buildConvolutionTensor();
    obj = obj.buildInverseTable();
    obj = obj.findIdentity();
    obj.isAbelian = obj.checkAbelian();
end
%% GPU Methods
function obj = enableGPU(obj)
    % enableGPU  Switch the object to the gpuArray code paths if a GPU exists.
    %
    %   Returns the updated object (value class: reassign the result).
    %   Emits the warning 'No GPU available' and leaves the object
    %   unchanged when no device can be found.
    hasGPU = false;
    try
        hasGPU = gpuDeviceCount > 0;
    catch
        % gpuDeviceCount is undefined without Parallel Computing Toolbox;
        % treat that the same as "no device".
    end

    if hasGPU
        obj.useGPU = true;
        obj = obj.initGPUArrays();
        % Index the device through a variable: dot-indexing directly into
        % a function call result does not parse on releases before R2019b.
        dev = gpuDevice();
        fprintf('GPU enabled: %s\n', dev.Name);
    else
        warning('No GPU available');
    end
end
function obj = initGPUArrays(obj)
    % initGPUArrays  Upload the group tables and Fourier matrices to the GPU.
    %
    %   The two index tables are cast to int32 before upload; the
    %   convolution tensor and Fourier matrices are uploaded as stored.

    % Index tables
    obj.G_gpu = gpuArray(int32(obj.G));
    obj.invTable_gpu = gpuArray(int32(obj.invTable));

    % Numeric operators
    obj.convTensor_gpu = gpuArray(obj.convTensor);
    obj.F_gpu = gpuArray(obj.F);
    obj.Finv_gpu = gpuArray(obj.Finv);
end
function obj = disableGPU(obj)
    % disableGPU  Return to the CPU code paths and drop the GPU copies.
    obj.useGPU = false;
    for field = {'G_gpu', 'convTensor_gpu', 'invTable_gpu', 'F_gpu', 'Finv_gpu'}
        obj.(field{1}) = [];
    end
end
%% Core Setup
function obj = findIdentity(obj)
    % findIdentity  Locate the two-sided identity of the multiplication table.
    %
    %   Sets obj.identityIdx to the first index e with G(e, a) == a and
    %   G(a, e) == a for every a in 1..n. Errors when no such e exists.
    labels = 1:obj.n;
    tbl = obj.G(labels, labels);

    actsAsLeftId = all(tbl == labels, 2);     % row e reproduces 1..n
    actsAsRightId = all(tbl == labels.', 1);  % column e reproduces 1..n

    e = find(actsAsLeftId & actsAsRightId.', 1);
    if isempty(e)
        error('No identity element found');
    end
    obj.identityIdx = e;
end
function obj = buildInverseTable(obj)
    % buildInverseTable  Fill obj.invTable with the inverse of each element.
    %
    %   The identity is taken as the first row of G equal to 1..n, falling
    %   back to index 1 when no row matches. invTable(a) is the first b
    %   with G(a, b) equal to that identity, or 0 when there is none.
    n = obj.n;

    e = find(all(obj.G == 1:n, 2), 1);
    if isempty(e)
        e = 1;
    end

    % max over a 0/1 row returns the first hit; rows without a hit report
    % value 0, which zeroes the reported position.
    [hasInverse, firstHit] = max(double(obj.G(1:n, 1:n) == e), [], 2);
    obj.invTable = firstHit .* hasInverse;
end
function obj = buildConvolutionTensor(obj)
    % buildConvolutionTensor  Encode the multiplication table as a 0/1 tensor.
    %
    %   obj.convTensor is ng x ng x ng with convTensor(a, b, c) = 1 exactly
    %   when c = G(a, b).
    ng = obj.n;
    [left, right] = ndgrid(1:ng, 1:ng);
    product = obj.G(1:ng, 1:ng);

    tensor = zeros(ng, ng, ng);
    tensor(sub2ind([ng, ng, ng], left(:), right(:), product(:))) = 1;
    obj.convTensor = tensor;
end
function isAb = checkAbelian(obj)
    % checkAbelian  True when the multiplication table equals its transpose.
    tableT = ctranspose(obj.G);
    isAb = isequal(obj.G, tableT);
end
%% Convolution Methods
function c = convolve_direct(obj, a, b)
    % convolve_direct  Group convolution via the convolution tensor.
    %
    %   c(k) = sum over (x, y) with G(x, y) = k of a(x) * b(y).
    %   a, b are length-n vectors (any orientation); c is n x 1 and is
    %   gathered back to host memory when the GPU path is active.
    a = a(:);
    b = b(:);
    if obj.useGPU
        a = gpuArray(a);
        b = gpuArray(b);
        T = obj.convTensor_gpu;
    else
        T = obj.convTensor;
    end

    % Plain transpose: the product is bilinear, so b must not be
    % conjugated (b' would conjugate complex input and disagree with
    % convolve_inverse and the FFT path in convolve).
    outer = a * b.';
    c = squeeze(sum(sum(T .* outer, 1), 2));

    if obj.useGPU
        c = gather(c);
    end
end
function c = convolve_inverse(obj, a, b)
    % convolve_inverse  Group convolution written with explicit inverses.
    %
    %   c(k) = sum over g of a(g) * b(g^{-1} * k), using obj.invTable for
    %   the inverse and obj.G for the product. Returns an n x 1 vector.
    a = a(:);
    b = b(:);
    order = obj.n;
    c = zeros(order, 1);
    for k = 1:order
        for g = 1:order
            partner = obj.G(obj.invTable(g), k);
            c(k) = c(k) + a(g) * b(partner);
        end
    end
end
function c = convolve(obj, a, b)
    % convolve  Group convolution of two length-n vectors.
    %
    %   Cyclic groups use the FFT (on the GPU when enabled, with the
    %   result gathered); every other group is delegated to
    %   convolve_direct. The result is made real when both inputs are.
    a = a(:);
    b = b(:);

    if ~obj.isCyclic
        c = obj.convolve_direct(a, b);
        return;
    end

    if obj.useGPU
        a = gpuArray(a);
        b = gpuArray(b);
    end
    c = ifft(fft(a) .* fft(b));
    if obj.useGPU
        c = gather(c);
    end

    if isreal(a) && isreal(b)
        c = real(c);
    end
end
%% Star-G Methods
function C = starG_direct(obj, A, B)
    % starG_direct  Star-G product straight from the multiplication table.
    %
    %   A is l x m x n, B is m x p x n, C is l x p x n with
    %       C(:, :, G(x, y)) += A(:, :, x) * B(:, :, y)
    %   for every pair of group elements (x, y). This is the same sum the
    %   convolution tensor encodes, but costs n^2 matrix products instead
    %   of masking an n x n x n tensor for every (i, j, k) entry.
    %
    %   The product is bilinear: B is never conjugated, so complex inputs
    %   agree with starG_fourier.
    [l, m, ng] = size(A);
    [mB, p, ngB] = size(B);
    if m ~= mB
        error('StarGAlgebra:DimensionMismatch', 'Inner dimensions must match');
    end
    if ng ~= obj.n || ngB ~= obj.n
        error('StarGAlgebra:DimensionMismatch', 'Mode-3 must equal group order');
    end

    if obj.useGPU
        A = gpuArray(A);
        B = gpuArray(B);
        C = gpuArray.zeros(l, p, ng);
    else
        C = zeros(l, p, ng);
    end

    for x = 1:ng
        Ax = A(:, :, x);
        for y = 1:ng
            z = obj.G(x, y);
            C(:, :, z) = C(:, :, z) + Ax * B(:, :, y);
        end
    end

    if obj.useGPU
        C = gather(C);
    end
end
function C = starG_fourier(obj, A, B)
    % starG_fourier  Star-G product through the FFT along mode 3.
    %
    %   A is l x m x n, B is m x p x n. Both are transformed with fft
    %   along dimension 3, multiplied slice by slice (pagemtimes when
    %   available, otherwise a loop), and transformed back. The result is
    %   gathered from the GPU when enabled and made real when both
    %   inputs are real.
    [l, ~, n] = size(A);
    [~, p, ~] = size(B);
    onGPU = obj.useGPU;

    if onGPU
        A = gpuArray(A);
        B = gpuArray(B);
    end
    Ahat = fft(A, [], 3);
    Bhat = fft(B, [], 3);

    havePaged = exist('pagemtimes', 'builtin') || exist('pagemtimes', 'file');
    if havePaged
        Chat = pagemtimes(Ahat, Bhat);
    else
        if onGPU
            Chat = gpuArray.zeros(l, p, n);
        else
            Chat = zeros(l, p, n);
        end
        for k = 1:n
            Chat(:,:,k) = Ahat(:,:,k) * Bhat(:,:,k);
        end
    end

    C = ifft(Chat, [], 3);
    if onGPU
        C = gather(C);
    end
    if isreal(A) && isreal(B)
        C = real(C);
    end
end
function C = starG(obj, A, B)
    % starG  Star-G product of two tensors, dispatching on the group type.
    %
    %   A is l x m x n, B is m x p x n, C is l x p x n, where n is the
    %   group order. A 2-D input is replicated along mode 3 to n slices.
    %   Cyclic groups use starG_fourier, all others starG_direct.
    %
    %   NOTE(review): replicating a matrix into every slice is what the
    %   code does; confirm that this, rather than placing it only at the
    %   identity slice, is the intended embedding.
    if ndims(A) == 2 && obj.n > 1
        A = repmat(A, [1, 1, obj.n]);
    end
    if ndims(B) == 2 && obj.n > 1
        B = repmat(B, [1, 1, obj.n]);
    end

    [~, innerA, slicesA] = size(A);
    [innerB, ~, slicesB] = size(B);
    assert(innerA == innerB, 'Inner dimensions must match');
    assert(slicesA == obj.n && slicesB == obj.n, 'Mode-3 must equal group order');

    if obj.isCyclic
        C = obj.starG_fourier(A, B);
    else
        C = obj.starG_direct(A, B);
    end
end
%% Batch Operations
function C = starG_batch(obj, A, B)
    % starG_batch  Star-G product applied to every item of a batch.
    %
    %   A is l x m x n x batch, B is m x p x n x batch; C is
    %   l x p x n x batch with C(:,:,:,i) = A(:,:,:,i) *_G B(:,:,:,i).
    %
    %   With the GPU enabled and a cyclic group the whole batch is
    %   processed in the Fourier domain at once; otherwise each item is
    %   passed to starG. The result is real only when both inputs are
    %   real, matching starG_fourier.
    [l, ~, ng, batch] = size(A);
    [~, p, ~, batchB] = size(B);
    if batchB ~= batch
        error('StarGAlgebra:DimensionMismatch', ...
            'A and B must have the same batch size');
    end

    if obj.useGPU && obj.isCyclic
        A = gpuArray(A);
        B = gpuArray(B);
        Ahat = fft(A, [], 3);
        Bhat = fft(B, [], 3);

        % Same availability test as starG_fourier, evaluated once.
        if exist('pagemtimes', 'builtin') || exist('pagemtimes', 'file')
            % pagemtimes multiplies matching pages across all trailing
            % dimensions, so slices and batch items are covered together.
            Chat = pagemtimes(Ahat, Bhat);
        else
            Chat = gpuArray.zeros(l, p, ng, batch);
            for b_idx = 1:batch
                for k = 1:ng
                    Chat(:,:,k,b_idx) = Ahat(:,:,k,b_idx) * Bhat(:,:,k,b_idx);
                end
            end
        end

        C = ifft(Chat, [], 3);
        % Only drop the imaginary part when it is rounding noise.
        if isreal(A) && isreal(B)
            C = real(C);
        end
        C = gather(C);
    else
        C = zeros(l, p, ng, batch);
        for b_idx = 1:batch
            C(:,:,:,b_idx) = obj.starG(A(:,:,:,b_idx), B(:,:,:,b_idx));
        end
    end
end
function C = starG_old(obj, A, B)
% starG_old  Legacy ★_G product that works through obj.F and obj.convTensor.
% A: l x m x n, B: m x p x n -> C: l x p x n
%
% Steps, as implemented below:
%   1. transform A and B along mode 3 with obj.F' (conjugate transpose),
%   2. combine the transformed slices with the convolution tensor,
%   3. transform back with obj.Finv'.
% The result is made real when both inputs are real.
%
% NOTE(review): step 2 combines slices through convTensor instead of
% multiplying matching slices as starG_fourier does; confirm this legacy
% path is still meant to agree with starG before relying on it.
szA = size(A);
szB = size(B);
% A 2-D input is replicated along mode 3 to obj.n slices
if ndims(A) == 2
A = reshape(A, [szA(1), szA(2), 1]);
if size(A,3) == 1 && obj.n > 1
A = repmat(A, [1, 1, obj.n]);
end
end
if ndims(B) == 2
B = reshape(B, [szB(1), szB(2), 1]);
if size(B,3) == 1 && obj.n > 1
B = repmat(B, [1, 1, obj.n]);
end
end
[l, m, ~] = size(A);
[m2, p, ~] = size(B);
n = obj.n;
assert(m == m2, 'Inner dimensions must match');
% Transform along mode 3: bring mode 3 to the front and flatten the other
% two modes so the transform is a single matrix product with obj.F'
A_reshape = reshape(permute(A, [3, 1, 2]), n, []); % n x (l*m)
Ahat_reshape = obj.F' * A_reshape; % n x (l*m)
Ahat = permute(reshape(Ahat_reshape, n, l, m), [2, 3, 1]); % l x m x n
B_reshape = reshape(permute(B, [3, 1, 2]), n, []); % n x (m*p)
Bhat_reshape = obj.F' * B_reshape; % n x (m*p)
Bhat = permute(reshape(Bhat_reshape, n, m, p), [2, 3, 1]); % m x p x n
% Combine in the transform domain: output slice k accumulates
% Ahat(:,:,x) * Bhat(:,:,y) over every (x, y) with convTensor(x, y, k) ~= 0
Chat = zeros(l, p, n);
for k = 1:n
Ck_slice = zeros(l, p);
for x = 1:n
for y = 1:n
if obj.convTensor(x, y, k) ~= 0
Ck_slice = Ck_slice + obj.convTensor(x,y,k) * (Ahat(:,:,x) * Bhat(:,:,y));
end
end
end
Chat(:,:,k) = Ck_slice;
end
% Transform back with obj.Finv' using the same flatten/unflatten pattern
Chat_reshape = reshape(permute(Chat, [3, 1, 2]), n, []); % n x (l*p)
C_reshape = obj.Finv' * Chat_reshape; % n x (l*p)
C = permute(reshape(C_reshape, n, l, p), [2, 3, 1]); % l x p x n
% Real inputs: discard the imaginary part of the result
if isreal(A) && isreal(B)
C = real(C);
end
end
%% Conjugate Transpose
function Ah = conjugateTranspose(obj, A)
    % conjugateTranspose  Star-G conjugate transpose, reference version.
    %
    %   For A of size m x p x n returns Ah of size p x m x n with
    %   Ah(i, j, g) = conj(A(j, i, invTable(g))).
    [m, p, n] = size(A);
    Ah = zeros(p, m, n);
    for g = 1:n
        % Page-wise conjugate transpose of the slice at the inverse element.
        Ah(:, :, g) = A(:, :, obj.invTable(g))';
    end
end
function Ah = conjugateTranspose_fast(obj, A)
    % conjugateTranspose_fast  Vectorized Star-G conjugate transpose.
    %
    %   Reorders the slices by obj.invTable, swaps the first two modes and
    %   conjugates: Ah(i, j, g) = conj(A(j, i, invTable(g))).
    reordered = A(:, :, obj.invTable);
    Ah = conj(permute(reordered, [2, 1, 3]));
end
%% SVD
function [U, S, V] = starG_SVD(obj, A)
    % starG_SVD  Star-G SVD through a transform along mode 3.
    %
    %   [U, S, V] = starG_SVD(obj, A) for A of size l x m x n returns
    %   U (l x r x n), S (r x r x n) and V (m x r x n), r = min(l, m).
    %
    %   1. Transform along mode 3: fft for cyclic groups, obj.F otherwise.
    %   2. Economy SVD of every transformed slice.
    %   3. Transform the factors back: ifft for cyclic groups, obj.Finv
    %      otherwise.
    %   Real input yields real factors.
    %
    %   For a cyclic group and real A, slice k and slice n-k+2 of the FFT
    %   are complex conjugates. Only the first floor(n/2)+1 slices are
    %   decomposed and the rest are filled in as conjugates, so the
    %   factors are conjugate-symmetric by construction and taking the
    %   real part after the inverse FFT discards rounding noise only.
    %   Independent SVDs of mirrored slices do not guarantee this because
    %   singular vectors are only defined up to phase.
    [l, m, n] = size(A);
    minlm = min(l, m);
    realInput = isreal(A);

    % Helpers to apply an n x n matrix along mode 3.
    toRows = @(X) reshape(permute(X, [3, 1, 2]), n, []);
    toSlices = @(X, r1, r2) permute(reshape(X, n, r1, r2), [2, 3, 1]);

    % --- Forward transform along mode 3 ---
    if obj.isCyclic
        Ahat = fft(A, [], 3);
    else
        Ahat = toSlices(obj.F * toRows(A), l, m);
    end

    % --- SVD of each transformed slice ---
    mirror = obj.isCyclic && realInput;
    if mirror
        lastIndependent = floor(n / 2) + 1;
    else
        lastIndependent = n;
    end

    Uhat = zeros(l, minlm, n);
    Shat = zeros(minlm, minlm, n);
    Vhat = zeros(m, minlm, n);
    for i = 1:lastIndependent
        slice = Ahat(:, :, i);
        if obj.useGPU, slice = gather(slice); end
        if mirror && (i == 1 || 2 * (i - 1) == n)
            % The zero-frequency and Nyquist slices of a real signal are
            % real; strip any stored imaginary part so svd returns real
            % factors for them.
            slice = real(slice);
        end
        [Ui, Si, Vi] = svd(slice, 'econ');
        k = size(Ui, 2);
        Uhat(:, 1:k, i) = Ui;
        Shat(1:k, 1:k, i) = Si;
        Vhat(:, 1:k, i) = Vi;
    end
    for i = lastIndependent + 1:n
        src = n - i + 2;
        Uhat(:, :, i) = conj(Uhat(:, :, src));
        Shat(:, :, i) = Shat(:, :, src);
        Vhat(:, :, i) = conj(Vhat(:, :, src));
    end

    % --- Inverse transform ---
    if obj.isCyclic
        U = ifft(Uhat, [], 3);
        S = ifft(Shat, [], 3);
        V = ifft(Vhat, [], 3);
    else
        U = toSlices(obj.Finv * toRows(Uhat), l, minlm);
        S = toSlices(obj.Finv * toRows(Shat), minlm, minlm);
        V = toSlices(obj.Finv * toRows(Vhat), m, minlm);
    end

    if realInput
        U = real(U); S = real(S); V = real(V);
    end
end
function Ak = truncate(obj, A, k)
    % truncate  Rank-k truncation of every transformed slice of A.
    %
    %   A is l x m x n. k is clipped to min(l, m). Each slice of the
    %   mode-3 transform (fft for cyclic groups, obj.F otherwise) is
    %   replaced by its leading-k SVD reconstruction, then the tensor is
    %   transformed back (ifft or obj.Finv). Real input gives real output.
    [l, m, n] = size(A);
    k = min(k, min(l, m));

    % --- Forward transform ---
    if obj.isCyclic
        Ahat = fft(A, [], 3);
    else
        flatA = reshape(permute(A, [3,1,2]), n, []);
        Ahat = permute(reshape(obj.F * flatA, n, l, m), [2,3,1]);
    end

    % --- Slice-wise truncated SVD ---
    Akhat = zeros(l, m, n);
    for i = 1:n
        slice = Ahat(:,:,i);
        if obj.useGPU
            slice = gather(slice);
        end
        [Ui, Si, Vi] = svd(slice, 'econ');
        keep = 1:min(k, size(Si, 1));
        Akhat(:,:,i) = Ui(:,keep) * Si(keep,keep) * Vi(:,keep)';
    end

    % --- Inverse transform ---
    if obj.isCyclic
        Ak = ifft(Akhat, [], 3);
    else
        flatK = reshape(permute(Akhat, [3,1,2]), n, []);
        Ak = permute(reshape(obj.Finv * flatK, n, l, m), [2,3,1]);
    end

    if isreal(A)
        Ak = real(Ak);
    end
end
%% Group Initialization
function obj = initCyclic(obj, n)
    % initCyclic  Set up the cyclic group Z_n.
    %
    %   n must be a positive integer scalar. Elements are indexed 1..n
    %   with G(i, j) = mod((i-1) + (j-1), n) + 1, so index 1 is the
    %   identity. F is the DFT matrix fft(eye(n)) and Finv = conj(F)/n.
    %
    %   Errors with StarGAlgebra:InvalidOrder for a missing or invalid n,
    %   rather than silently building empty tables.
    if ~(isnumeric(n) && isscalar(n) && isfinite(n) && n >= 1 && n == floor(n))
        error('StarGAlgebra:InvalidOrder', ...
            'Cyclic group order must be a positive integer scalar.');
    end
    n = double(n);

    obj.n = n;
    obj.isCyclic = true;
    k = 0:n-1;
    obj.G = mod(k.' + k, n) + 1;
    obj.identityIdx = 1;
    obj.F = fft(eye(n));
    obj.Finv = conj(obj.F) / n;
end
function obj = initKlein4(obj)
    % initKlein4  Set up the Klein four-group (order 4, identity at index 1).
    %
    %   F holds four +/-1 rows and Finv is F / 4.
    obj.n = 4;
    obj.isCyclic = false;
    obj.identityIdx = 1;

    obj.G = [1 2 3 4;
             2 1 4 3;
             3 4 1 2;
             4 3 2 1];

    signs = [1  1  1  1;
             1  1 -1 -1;
             1 -1  1 -1;
             1 -1 -1  1];
    obj.F = signs;
    obj.Finv = obj.F / 4;
end
function obj = initDihedral(obj, n)
    % initDihedral  Set up the dihedral group D_n of order 2n.
    %
    %   D_n = <r, s | r^n = s^2 = 1, srs = r^{-1}>. Indices 1..n are the
    %   rotations r^0..r^(n-1); indices n+1..2n are the reflections.
    %   Index 1 is the identity.
    %
    %   Each row of F belongs to one group element and concatenates the
    %   row-major entries of its irrep matrices: trivial, sign, the
    %   floor((n-1)/2) two-dimensional irreps and, for even n, two more
    %   one-dimensional irreps. All entries are divided by sqrt(2n) and
    %   Finv is inv(F). Mirrors
    %   python/large_scale/starg_torch/algebra.py:_build_F_dihedral.
    order = 2 * n;
    obj.n = order;
    obj.isCyclic = false;

    % --- Multiplication table ---
    % Each index carries a rotation exponent and a reflection flag. A
    % reflection on the left flips the sign of the right exponent; the
    % product is a reflection iff exactly one factor is.
    elem = 0:order-1;
    expo = mod(elem, n);
    refl = elem >= n;
    leftSign = 1 - 2 * refl(:);
    prodExpo = mod(expo(:) + leftSign .* expo, n);
    prodRefl = refl(:) ~= refl;
    obj.G = prodExpo + 1 + n * prodRefl;
    obj.identityIdx = 1;

    % --- Fourier matrix from the irreps ---
    n_2d = floor((n - 1) / 2);
    nIsEven = mod(n, 2) == 0;
    dims = [1, 1, repmat(2, 1, n_2d)];
    if nIsEven
        dims = [dims, 1, 1];
    end

    entries = zeros(order, order);
    for g = 1:order
        r = mod(g - 1, n);
        isReflection = g > n;

        entries(g, 1) = 1;                      % trivial
        entries(g, 2) = 1 - 2 * isReflection;   % sign

        col = 2;
        for k = 1:n_2d
            ang = 2 * pi * k * r / n;
            cs = cos(ang);
            sn = sin(ang);
            if isReflection
                blockRow = [cs, sn, sn, -cs];
            else
                blockRow = [cs, -sn, sn, cs];
            end
            entries(g, col+1:col+4) = blockRow;
            col = col + 4;
        end

        if nIsEven
            sgnRefl = 1 - 2 * isReflection;
            sgnRot = 1 - 2 * (mod(r, 2) ~= 0);
            entries(g, col+1) = sgnRefl * sgnRot;
            entries(g, col+2) = sgnRot;
        end
    end

    obj.F = complex(entries / sqrt(order));
    obj.Finv = inv(obj.F);
    obj.irrepDims = dims;
end
function obj = initSymmetric(obj, n)
    % initSymmetric  Set up the symmetric group S_n (abelian cases only).
    %
    %   n must be a positive integer scalar.
    %
    %   n >= 3: S_n is non-abelian and its irreps need the Specht-module /
    %   Young-symmetrizer construction, which has not been ported to
    %   MATLAB. A plain fft(eye(n!)) would be mathematically wrong there,
    %   so the method raises StarGAlgebra:NotImplemented. The check runs
    %   before any table is built, so no n! x n! work is wasted. Use the
    %   Python pipeline (discover_at_scale.compute_irrep_fourier_power)
    %   for the published symmetric-group experiments.
    %
    %   n <= 2: S_1 and S_2 are abelian, the DFT matrix is a valid Fourier
    %   matrix, and the group is constructed normally with the identity
    %   permutation at index 1.
    if ~(isnumeric(n) && isscalar(n) && isfinite(n) && n >= 1 && n == floor(n))
        error('StarGAlgebra:InvalidOrder', ...
            'Symmetric group degree must be a positive integer scalar.');
    end
    if n >= 3
        error('StarGAlgebra:NotImplemented', ...
            ['initSymmetric: S_%d is non-abelian for n>=3 and ', ...
            'requires Specht-module irrep construction not yet ', ...
            'available in MATLAB. Use the Python pipeline at ', ...
            'python/large_scale/starg_torch for non-abelian work.'], n);
    end

    obj.isCyclic = false;
    perms_list = perms(1:n);
    obj.n = size(perms_list, 1);

    % Move the identity permutation to the first row.
    identity_idx = find(all(perms_list == 1:n, 2), 1);
    if identity_idx ~= 1
        perms_list([1, identity_idx], :) = perms_list([identity_idx, 1], :);
    end

    % G(i, j) is the row index of the composition perm_i(perm_j(.)).
    obj.G = zeros(obj.n, obj.n);
    for i = 1:obj.n
        for j = 1:obj.n
            composed = perms_list(i, perms_list(j, :));
            [~, obj.G(i, j)] = ismember(composed, perms_list, 'rows');
        end
    end
    obj.identityIdx = 1;

    obj.F = fft(eye(obj.n));
    obj.Finv = conj(obj.F) / obj.n;
end
function obj = initQuaternion(obj)
    % initQuaternion  Set up the quaternion group Q_8.
    %
    %   Element encoding: 1=1, 2=-1, 3=i, 4=-i, 5=j, 6=-j, 7=k, 8=-k.
    %   Q_8 has five irreps: four one-dimensional ones and one faithful
    %   two-dimensional one with rho(i) = [1i 0; 0 -1i],
    %   rho(j) = [0 1; -1 0], rho(k) = rho(i) * rho(j).
    %
    %   Row g of F is [chi_triv, chi_a, chi_b, chi_c, rho(g) row-major],
    %   divided by sqrt(8). Finv is computed with inv(F).
    obj.n = 8;
    obj.isCyclic = false;
    obj.identityIdx = 1;
    obj.G = [1 2 3 4 5 6 7 8;
             2 1 4 3 6 5 8 7;
             3 4 2 1 7 8 6 5;
             4 3 1 2 8 7 5 6;
             5 6 8 7 2 1 3 4;
             6 5 7 8 1 2 4 3;
             7 8 5 6 4 3 2 1;
             8 7 6 5 3 4 1 2];

    % One-dimensional characters, one column each:
    %   chi_a: i,-i -> -1; j,-j -> +1; k,-k -> -1
    %   chi_b: i,-i -> +1; j,-j -> -1; k,-k -> -1
    %   chi_c: i,-i -> -1; j,-j -> -1; k,-k -> +1
    oneDim = [ones(8, 1), ...
              [1 1 -1 -1  1  1 -1 -1].', ...
              [1 1  1  1 -1 -1 -1 -1].', ...
              [1 1 -1 -1 -1 -1  1  1].'];

    % Two-dimensional irrep: elements come in +/- pairs (1, i, j, k).
    qi = [1i, 0; 0, -1i];
    qj = [0, 1; -1, 0];
    generators = {eye(2), qi, qj, qi * qj};
    twoDim = zeros(8, 4);
    for t = 1:4
        flat = reshape(generators{t}.', 1, 4);   % row-major entries
        twoDim(2*t - 1, :) = flat;
        twoDim(2*t, :) = -flat;
    end

    obj.F = complex([oneDim, twoDim]) / sqrt(8);
    obj.Finv = inv(obj.F);
    obj.irrepDims = [1, 1, 1, 1, 2];
end
function obj = initOctahedral(obj)
% initOctahedral  Set up the chiral octahedral group O of order 24.
%
% Five irreps: A_1 (trivial, 1-d), A_2 (sign of permutation on 4 body
% diagonals, 1-d), E (2-d, l=2 doublet on Sym^2(R^3)/trace),
% T_1 (3-d, the rotation matrices themselves, l=1), T_2 (3-d,
% traceless symmetric tensor part, l=2). Mirrors
% python/large_scale/starg_torch/octahedral.octahedral_irreps.
%
% Sets n, G, identityIdx, F, Finv and irrepDims. Row g of F is
% [A_1, A_2, E row-major, T_1 row-major, T_2 row-major] for element g,
% divided by sqrt(24); Finv is inv(F).
obj.isCyclic = false;
obj.n = 24;
% Element g is the 3x3 matrix R{g}; R{1} is the identity matrix.
R = obj.octahedralRotations();
% Build multiplication table: T(i, j) is the index k whose matrix matches
% R{i} * R{j} to within 1e-6.
T = zeros(24, 24);
for i = 1:24
for j = 1:24
P = R{i} * R{j};
found = -1;
for k = 1:24
if max(abs(P(:) - R{k}(:))) < 1e-6
found = k; break;
end
end
if found < 0
error('octahedral product not in group');
end
T(i, j) = found;
end
end
obj.G = T;
obj.identityIdx = 1;
% A_1 (trivial)
A1 = ones(24, 1);
% A_2: sign of permutation induced on the 4 body diagonals.
% A rotated diagonal is matched up to overall sign, because a diagonal
% and its negative are the same axis.
diags = [ 1 1 1; 1 1 -1; 1 -1 1; -1 1 1];
A2 = zeros(24, 1);
for g = 1:24
permuted = (R{g} * diags')';
idx = zeros(4, 1);
for ii = 1:4
for jj = 1:4
if max(abs(permuted(ii, :) - diags(jj, :))) < 1e-6 || ...
max(abs(permuted(ii, :) + diags(jj, :))) < 1e-6
idx(ii) = jj; break;
end
end
end
A2(g) = obj.permSign(idx);
end
% T_1: the rotation matrices (3-d)
% E and T_2: from the symmetric traceless rank-2 representation.
% 5-d basis: (xx-yy)/sqrt(2), (2zz-xx-yy)/sqrt(6), xy, xz, yz.
% Rows of B are these basis vectors in (xx,yy,zz,xy,xz,yz) coordinates,
% normalized to unit length below.
B = [1, -1, 0, 0, 0, 0;
-1, -1, 2, 0, 0, 0;
0, 0, 0, 1, 0, 0;
0, 0, 0, 0, 1, 0;
0, 0, 0, 0, 0, 1];
for r = 1:5
B(r, :) = B(r, :) / norm(B(r, :));
end
% E is the leading 2x2 block and T_2 the trailing 3x3 block of the 5-d
% matrix B * M * B'.
E_mats = cell(24, 1);
T2_mats = cell(24, 1);
for g = 1:24
Mfull = obj.sym2RepFull(R{g}); % (6,6)
M5 = B * Mfull * B'; % (5,5) traceless
E_mats{g} = M5(1:2, 1:2);
T2_mats{g} = M5(3:5, 3:5);
end
% Assemble F: 1 + 1 + 4 + 9 + 9 = 24 columns per element.
rows = zeros(24, 24);
for g = 1:24
row = [A1(g), A2(g)];
row = [row, E_mats{g}(1,1), E_mats{g}(1,2), E_mats{g}(2,1), E_mats{g}(2,2)];
T1g = R{g};
row = [row, T1g(1,1), T1g(1,2), T1g(1,3), ...
T1g(2,1), T1g(2,2), T1g(2,3), ...
T1g(3,1), T1g(3,2), T1g(3,3)];
T2g = T2_mats{g};
row = [row, T2g(1,1), T2g(1,2), T2g(1,3), ...
T2g(2,1), T2g(2,2), T2g(2,3), ...
T2g(3,1), T2g(3,2), T2g(3,3)];
rows(g, :) = row;
end
obj.F = complex(rows) / sqrt(24);
obj.Finv = inv(obj.F);
obj.irrepDims = [1, 1, 2, 3, 3];
end
function R = octahedralRotations(~)
    % octahedralRotations  The 24 rotation matrices of the octahedral group.
    %
    %   Returns a 24 x 1 cell array of 3x3 matrices in this order:
    %   identity; 9 rotations about the coordinate axes (angles pi/2, pi,
    %   -pi/2 for x, then y, then z); 8 rotations about the body
    %   diagonals (+/- 2*pi/3 each); 6 half-turns about the edge axes.
    %   Matrices within 1e-6 of an integer matrix are snapped to it.

    % Rodrigues formula: I + sin(t) K + (1 - cos(t)) K^2 for the
    % cross-product matrix K of the unit axis.
    crossMat = @(u) [0, -u(3), u(2);
                     u(3), 0, -u(1);
                     -u(2), u(1), 0];
    rodrigues = @(K, t) eye(3) + sin(t)*K + (1-cos(t))*(K*K);
    rotation = @(v, t) rodrigues(crossMat(v(:) / norm(v)), t);

    % Each row: a family of axes and the angles applied to every axis.
    families = { ...
        {[1;0;0], [0;1;0], [0;0;1]},                                   [pi/2, pi, -pi/2]; ...
        {[1;1;1], [1;1;-1], [1;-1;1], [-1;1;1]},                       [2*pi/3, -2*pi/3]; ...
        {[1;1;0], [1;-1;0], [1;0;1], [1;0;-1], [0;1;1], [0;1;-1]},     pi};

    R = cell(24, 1);
    R{1} = eye(3);
    next = 2;
    for f = 1:size(families, 1)
        axesInFamily = families{f, 1};
        anglesInFamily = families{f, 2};
        for a = 1:numel(axesInFamily)
            for t = anglesInFamily
                R{next} = rotation(axesInFamily{a}, t);
                next = next + 1;
            end
        end
    end

    % Remove floating-point noise from the trigonometric evaluation.
    for k = 1:24
        snapped = round(R{k});
        if max(abs(R{k}(:) - snapped(:))) < 1e-6
            R{k} = snapped;
        end
    end
end
function s = permSign(~, perm)
    % permSign  Sign (+1 or -1) of a permutation given as a vector.
    %
    %   Counts the pairs i < j with perm(i) > perm(j); an even count
    %   gives +1, an odd count gives -1.
    p = perm(:);
    outOfOrder = p > p.';                      % (i, j) true when p(i) > p(j)
    inversions = nnz(triu(outOfOrder, 1));     % keep only pairs with i < j
    if mod(inversions, 2) == 0
        s = 1;
    else
        s = -1;
    end
end
function M = sym2RepFull(~, R)
    % sym2RepFull  6x6 matrix of R acting on symmetric rank-2 tensors.
    %
    %   Basis order is (xx, yy, zz, xy, xz, yz). Entry (out, in) is
    %   R(a,c)*R(b,d) for a diagonal input component (c == d) and
    %   R(a,c)*R(b,d) + R(a,d)*R(b,c) for an off-diagonal one, where
    %   (a, b) and (c, d) are the index pairs of the output and input
    %   components.
    indexPairs = [1 1; 2 2; 3 3; 1 2; 1 3; 2 3];
    M = zeros(6, 6);
    for out = 1:6
        a = indexPairs(out, 1);
        b = indexPairs(out, 2);
        for in = 1:6
            c = indexPairs(in, 1);
            d = indexPairs(in, 2);
            entry = R(a, c) * R(b, d);
            if c ~= d
                % Off-diagonal inputs appear twice in the symmetric tensor.
                entry = entry + R(a, d) * R(b, c);
            end
            M(out, in) = entry;
        end
    end
end
function obj = initProduct(obj, groups)
    % initProduct  Set up the direct product of several groups.
    %
    %   groups is a cell array of objects exposing n, G, F and Finv.
    %   Elements of the product are tuples of factor indices, flattened
    %   with the last factor varying fastest. The product acts
    %   component-wise, and F / Finv are the Kronecker products of the
    %   factor matrices in the same order.
    numFactors = length(groups);
    orders = cellfun(@(g) g.n, groups);
    obj.n = prod(orders);
    obj.isCyclic = false;

    obj.G = zeros(obj.n, obj.n);
    for a = 1:obj.n
        a_idx = obj.linearToMultiIndex(a, orders);
        for b = 1:obj.n
            b_idx = obj.linearToMultiIndex(b, orders);
            prodTuple = zeros(1, numFactors);
            for f = 1:numFactors
                prodTuple(f) = groups{f}.G(a_idx(f), b_idx(f));
            end
            obj.G(a, b) = obj.multiToLinearIndex(prodTuple, orders);
        end
    end
    obj.identityIdx = 1;

    fwd = groups{1}.F;
    bwd = groups{1}.Finv;
    for f = 2:numFactors
        fwd = kron(fwd, groups{f}.F);
        bwd = kron(bwd, groups{f}.Finv);
    end
    obj.F = fwd;
    obj.Finv = bwd;
end
function idx = linearToMultiIndex(~, lin, orders)
    % linearToMultiIndex  Split a 1-based linear index into a 1-based tuple.
    %
    %   orders holds the size of each position; the last position varies
    %   fastest. Returns a 1 x numel(orders) row vector. Inverse of
    %   multiToLinearIndex.
    orders = orders(:).';
    % Stride of each position = product of the sizes to its right.
    strides = [fliplr(cumprod(fliplr(orders(2:end)))), 1];
    idx = mod(floor((lin - 1) ./ strides), orders) + 1;
end
function lin = multiToLinearIndex(~, idx, orders)
    % multiToLinearIndex  Flatten a 1-based tuple into a 1-based linear index.
    %
    %   orders holds the size of each position; the last position varies
    %   fastest. Inverse of linearToMultiIndex.
    count = length(orders);
    sizes = reshape(orders(1:count), 1, []);
    digits = reshape(idx(1:count), 1, []) - 1;
    % Stride of each position = product of the sizes to its right.
    strides = [fliplr(cumprod(fliplr(sizes(2:end)))), 1];
    lin = sum(digits .* strides(1:count)) + 1;
end
function obj = initFromTable(obj, multTable)
    % initFromTable  Set up a group from a user-supplied multiplication table.
    %
    %   multTable must be a square numeric matrix whose entries are
    %   integers in 1..n, with multTable(a, b) the index of a * b.
    %   Anything else raises StarGAlgebra:InvalidTable here, instead of
    %   an indexing error later while the convolution tensor is built.
    %
    %   F is always fft(eye(n)) and Finv = conj(F)/n. For a non-abelian
    %   table this is only a placeholder and a warning is issued: a
    %   correct generalized Fourier matrix needs the irrep matrices,
    %   which are not derived from the table in this implementation.
    %   Prefer the dedicated initializers ('octahedral', 'dihedral',
    %   'quaternion') or the Python pipeline at
    %   python/large_scale/starg_torch.
    %
    %   NOTE(review): for abelian tables, fft(eye(n)) is the Fourier
    %   matrix of the cyclic group in its standard ordering. Confirm it
    %   is acceptable for other abelian tables (e.g. Z_2 x Z_2, or a
    %   relabelled cyclic group) before using Fourier-domain operations.
    order = size(multTable, 1);
    if ~isnumeric(multTable) || ~ismatrix(multTable) || size(multTable, 2) ~= order
        error('StarGAlgebra:InvalidTable', ...
            'Multiplication table must be a square numeric matrix.');
    end
    entries = multTable(:);
    if any(entries < 1 | entries > order | entries ~= floor(entries))
        error('StarGAlgebra:InvalidTable', ...
            'Multiplication table entries must be integers in 1..%d.', order);
    end

    obj.n = order;
    obj.G = multTable;
    obj.isCyclic = false;

    if ~isequal(multTable, multTable.')
        warning('StarGAlgebra:NonAbelianTable', ...
            ['initFromTable: the supplied multiplication table ', ...
            'is non-abelian; a correct generalized Fourier ', ...
            'matrix requires the irrep matrices and is not ', ...
            'derivable from the table alone in this MATLAB ', ...
            'implementation. Falling back to the regular ', ...
            'representation as F (only valid in spirit; ', ...
            'incorrect for SVD/Fourier-domain operations). ', ...
            'Prefer StarGAlgebra(''octahedral''), ', ...
            'StarGAlgebra(''dihedral'', n), or ', ...
            'StarGAlgebra(''quaternion'').']);
    end

    obj.F = fft(eye(obj.n));
    obj.Finv = conj(obj.F) / obj.n;
end
%% Utility
function I = identity(obj, m)
    % identity  Identity tensor of the Star-G product.
    %
    %   Returns an m x m x n tensor that is eye(m) in the slice of the
    %   group identity (obj.identityIdx) and zero elsewhere.
    unitSlice = eye(m);
    I = zeros(m, m, obj.n);
    I(:, :, obj.identityIdx) = unitSlice;
end
function nrm = tensorNorm(~, A)
    % tensorNorm  Square root of the sum of squared magnitudes of all entries.
    magnitudes = abs(A(:));
    nrm = sqrt(sum(magnitudes.^2));
end
%% Benchmark
function benchmark(obj, sizes)
    % benchmark  Time starG against starG_direct on random square tensors.
    %
    %   sizes lists the matrix side lengths (default [10, 20, 50, 100]).
    %   For each size, starG is warmed up once and then averaged over 10
    %   runs; starG_direct is averaged over 3 runs, but only for sizes up
    %   to 30, and is reported as N/A above that.
    if nargin < 2
        sizes = [10, 20, 50, 100];
    end

    fprintf('\n=== Performance Benchmark ===\n');
    fprintf('Group: n=%d, Cyclic=%d, GPU=%d\n', obj.n, obj.isCyclic, obj.useGPU);
    fprintf('%-10s %-15s %-15s %-10s\n', 'Size', 'Direct (s)', 'Optimized (s)', 'Speedup');
    fprintf('%s\n', repmat('-', 1, 55));

    for sz = sizes
        A = randn(sz, sz, obj.n);
        B = randn(sz, sz, obj.n);
        obj.starG(A, B); % Warm up

        t_direct = NaN;
        if sz <= 30
            tic;
            for rep = 1:3
                obj.starG_direct(A, B);
            end
            t_direct = toc / 3;
        end

        tic;
        for rep = 1:10
            obj.starG(A, B);
        end
        t_opt = toc / 10;

        if isnan(t_direct)
            fprintf('%-10d %-15s %-15.4f %-10s\n', sz, 'N/A', t_opt, '-');
        else
            fprintf('%-10d %-15.4f %-15.4f %-10.1fx\n', sz, t_direct, t_opt, t_direct/t_opt);
        end
    end
end
%% Verification Suite
function runAllTests(obj)
    % runAllTests  Print a summary of the group and run every test* method.
    fprintf('========================================\n');
    fprintf('StarGAlgebra Verification Suite\n');
    fprintf('Group order: %d, Cyclic: %d, Abelian: %d\n', ...
        obj.n, obj.isCyclic, obj.isAbelian);
    fprintf('Identity at index: %d, GPU: %d\n', obj.identityIdx, obj.useGPU);
    fprintf('========================================\n\n');

    suite = {'testGroupAxioms', 'testConvolutionTensor', ...
             'testConvolutionMethods', 'testStarGMethods', ...
             'testConjugateTranspose', 'testIdentity', ...
             'testAssociativity', 'testSVD'};
    for t = 1:numel(suite)
        obj.(suite{t})();
    end

    fprintf('\n========================================\n');
    fprintf('All tests completed.\n');
    fprintf('========================================\n');
end
function testGroupAxioms(obj)
    % testGroupAxioms  Check identity, inverses and (sampled) associativity.
    %
    %   Identity and inverse failures are reported per element.
    %   Associativity is checked on min(n^3, 500) random triples and
    %   stops at the first failure.
    fprintf('Test 1: Group Axioms\n');
    e = obj.identityIdx;
    order = obj.n;
    passed = true;

    % Identity: e * a == a and a * e == a
    for a = 1:order
        leftOk = obj.G(e, a) == a;
        if ~(leftOk && obj.G(a, e) == a)
            fprintf(' FAIL: Identity for %d\n', a);
            passed = false;
        end
    end

    % Inverses: a * a^{-1} == e and a^{-1} * a == e
    for a = 1:order
        a_inv = obj.invTable(a);
        rightOk = obj.G(a, a_inv) == e;
        if ~(rightOk && obj.G(a_inv, a) == e)
            fprintf(' FAIL: Inverse for %d\n', a);
            passed = false;
        end
    end

    % Associativity on random triples
    numTrials = min(order^3, 500);
    for t = 1:numTrials
        a = randi(order); b = randi(order); c = randi(order);
        leftFirst = obj.G(obj.G(a,b), c);
        rightFirst = obj.G(a, obj.G(b,c));
        if leftFirst ~= rightFirst
            fprintf(' FAIL: Associativity\n');
            passed = false;
            break;
        end
    end

    if passed
        fprintf(' PASS\n');
    end
end
function testConvolutionTensor(obj)
    % testConvolutionTensor  Check convTensor(a, b, c) == (c == G(a, b)).
    fprintf('\nTest 2: Convolution Tensor\n');
    order = obj.n;
    table = obj.G(1:order, 1:order);

    passed = true;
    for c = 1:order
        % Slice c must be exactly the indicator of the products equal to c.
        expectedSlice = double(table == c);
        if ~isequal(obj.convTensor(1:order, 1:order, c), expectedSlice)
            passed = false;
        end
    end

    if passed
        fprintf(' PASS\n');
    else
        fprintf(' FAIL\n');
    end
end
function testConvolutionMethods(obj)
    % testConvolutionMethods  Compare the three 1-D convolution routines.
    %
    %   convolve_direct is the reference; convolve_inverse and convolve
    %   must agree with it to a relative error below 1e-6.
    fprintf('\nTest 3: 1D Convolution\n');
    a = randn(obj.n, 1);
    b = randn(obj.n, 1);

    reference = obj.convolve_direct(a, b);
    viaInverse = obj.convolve_inverse(a, b);
    viaMain = obj.convolve(a, b);

    err1 = norm(reference - viaInverse) / norm(reference);
    err2 = norm(reference - viaMain) / norm(reference);
    fprintf(' Direct vs Inverse: %.2e\n', err1);
    fprintf(' Direct vs Optimized: %.2e\n', err2);

    tol = 1e-6;
    if ~(err1 < tol && err2 < tol)
        fprintf(' FAIL\n');
    else
        fprintf(' PASS\n');
    end
end
function testStarGMethods(obj)
    % testStarGMethods  Compare starG against starG_direct on random input.
    %
    %   Uses a 3 x 4 x n and a 4 x 2 x n tensor, prints the relative
    %   error and both timings, and passes below a relative error of 1e-6.
    fprintf('\nTest 4: StarG Product\n');
    A = randn(3, 4, obj.n);
    B = randn(4, 2, obj.n);

    tic;
    C_direct = obj.starG_direct(A, B);
    t1 = toc;

    tic;
    C_main = obj.starG(A, B);
    t2 = toc;

    err = norm(C_direct(:) - C_main(:)) / norm(C_direct(:));
    fprintf(' Direct vs Main: error=%.2e (%.4fs vs %.4fs)\n', err, t1, t2);

    tol = 1e-6;
    if ~(err < tol)
        fprintf(' FAIL\n');
    else
        fprintf(' PASS\n');
    end
end
function testConjugateTranspose(obj)
    % testConjugateTranspose  Compare the loop and vectorized conjugate transposes.
    %
    %   Uses a random complex 3 x 4 x n tensor; passes below a relative
    %   error of 1e-10.
    fprintf('\nTest 5: Conjugate Transpose\n');
    A = randn(3, 4, obj.n) + 1i*randn(3, 4, obj.n);

    reference = obj.conjugateTranspose(A);
    fast = obj.conjugateTranspose_fast(A);
    err = norm(reference(:) - fast(:)) / norm(reference(:));

    if ~(err < 1e-10)
        fprintf(' FAIL (error=%.2e)\n', err);
    else
        fprintf(' PASS (error=%.2e)\n', err);
    end
end
function testIdentity(obj)
    % testIdentity  Check that the identity tensor is neutral on both sides.
    %
    %   Uses a random 4 x 4 x n tensor; passes when both relative errors
    %   are below 1e-10.
    fprintf('\nTest 6: Identity\n');
    m = 4;
    I = obj.identity(m);
    A = randn(m, m, obj.n);

    fromLeft = obj.starG(I, A);
    fromRight = obj.starG(A, I);
    err1 = norm(fromLeft(:) - A(:)) / norm(A(:));
    err2 = norm(fromRight(:) - A(:)) / norm(A(:));
    fprintf(' ||I*A - A||/||A|| = %.2e\n', err1);
    fprintf(' ||A*I - A||/||A|| = %.2e\n', err2);

    tol = 1e-10;
    if ~(err1 < tol && err2 < tol)
        fprintf(' FAIL\n');
    else
        fprintf(' PASS\n');
    end
end
function testAssociativity(obj)
    % testAssociativity  Check (A * B) * C == A * (B * C) for the Star-G product.
    %
    %   Uses random 2 x 3 x n, 3 x 2 x n and 2 x 2 x n tensors; passes
    %   below a relative error of 1e-10.
    fprintf('\nTest 7: Associativity\n');
    A = randn(2, 3, obj.n);
    B = randn(3, 2, obj.n);
    C = randn(2, 2, obj.n);

    ABC_left = obj.starG(obj.starG(A, B), C);
    ABC_right = obj.starG(A, obj.starG(B, C));
    err = norm(ABC_left(:) - ABC_right(:)) / norm(ABC_left(:));

    if ~(err < 1e-10)
        fprintf(' FAIL (error=%.2e)\n', err);
    else
        fprintf(' PASS (error=%.2e)\n', err);
    end
end
function testSVD(obj)
    % testSVD  Check that U * S * V^H reconstructs A (cyclic groups only).
    %
    %   Skipped for non-cyclic groups. Uses a random 4 x 3 x n tensor and
    %   passes below a relative reconstruction error of 1e-10.
    fprintf('\nTest 8: SVD\n');
    if ~obj.isCyclic
        fprintf(' SKIP (non-cyclic)\n');
        return;
    end

    A = randn(4, 3, obj.n);
    [U, S, V] = obj.starG_SVD(A);
    US = obj.starG(U, S);
    A_rec = obj.starG(US, obj.conjugateTranspose(V));
    err = norm(A(:) - A_rec(:)) / norm(A(:));

    if ~(err < 1e-10)
        fprintf(' FAIL (error=%.2e)\n', err);
    else
        fprintf(' PASS (error=%.2e)\n', err);
    end
end
end
end
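%% ========================================================================
%% ADDITIONAL METHODS FOR StarGAlgebra CLASS
%% These functions take the object as their first argument and are written
%% as methods: move them inside the `methods` block of the classdef in your
%% StarGAlgebra.m file. Left here, after the closing `end` of the classdef,
%% they are file-local functions and cannot be called as obj.name(...).
%% ========================================================================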
function [U, S, V, sv_invariant] = starG_SVD_stable(obj, A)
    %% starG_SVD_stable - ★_G-SVD factors plus a struct of summary quantities
    %%
    %% INPUT:
    %%   A - l x m x n tensor.
    %%
    %% OUTPUTS:
    %%   U, S, V      - factors of size l x r x n, r x r x n, m x r x n,
    %%                  r = min(l, m). Real when A is real.
    %%   sv_invariant - struct with fields
    %%     .magnitudes - singular values per slice, sorted descending
    %%                   within each column. Cyclic groups: one column per
    %%                   frequency of the one-sided spectrum
    %%                   (floor(n/2)+1 columns). Other groups: one column
    %%                   per slice of A.
    %%     .power_spec - abs(fft(A, [], 3)).^2
    %%     .trace_invs - 4 x 1: [sum of |entries|^2;
    %%                           sum over slices of trace(Ag*Ag');
    %%                           sum over slices of trace((Ag*Ag')^2);
    %%                           nuclear norm of A flattened to (l*m) x n]
    %%
    %% For a cyclic group and real A, only the first floor(n/2)+1 Fourier
    %% slices are decomposed; the remaining slices are filled in as complex
    %% conjugates so that the real part taken after the inverse FFT
    %% discards rounding noise only.
    %%
    %% NOTE(review): for non-cyclic groups the SVD is taken of each slice
    %% of A directly, with no transform; confirm this is the intended
    %% decomposition for those groups.
    [l, m, n] = size(A);
    minlm = min(l, m);
    realInput = isreal(A);
    sv_invariant = struct();

    % Shared by the cyclic branch and the power spectrum.
    Ahat = fft(A, [], 3);

    if obj.isCyclic
        % === Fourier domain computation ===
        n_freq = floor(n/2) + 1; % One-sided spectrum
        Uhat = zeros(l, minlm, n);
        Shat = zeros(minlm, minlm, n);
        Vhat = zeros(m, minlm, n);
        sv_mags = zeros(minlm, n_freq);
        for k = 1:n
            if realInput && k > n_freq
                % Mirror frequency: conjugate of an already computed slice.
                src = n - k + 2;
                Uhat(:, :, k) = conj(Uhat(:, :, src));
                Shat(:, :, k) = Shat(:, :, src);
                Vhat(:, :, k) = conj(Vhat(:, :, src));
                continue;
            end
            slice = Ahat(:, :, k);
            if realInput && (k == 1 || 2 * (k - 1) == n)
                % Zero-frequency and Nyquist slices of a real signal are real.
                slice = real(slice);
            end
            [Uk, Sk, Vk] = svd(slice, 'econ');
            kk = size(Uk, 2);
            Uhat(:, 1:kk, k) = Uk;
            Shat(1:kk, 1:kk, k) = Sk;
            Vhat(:, 1:kk, k) = Vk;
            if k <= n_freq
                sv_mags(1:kk, k) = abs(diag(Sk));
            end
        end
        U = ifft(Uhat, [], 3);
        S = ifft(Shat, [], 3);
        V = ifft(Vhat, [], 3);
    else
        % === Slice-wise computation for non-cyclic groups ===
        U = zeros(l, minlm, n);
        S = zeros(minlm, minlm, n);
        V = zeros(m, minlm, n);
        sv_mags = zeros(minlm, n);
        for g = 1:n
            [Ug, Sg, Vg] = svd(A(:, :, g), 'econ');
            kk = size(Ug, 2);
            U(:, 1:kk, g) = Ug;
            S(1:kk, 1:kk, g) = Sg;
            V(:, 1:kk, g) = Vg;
            sv_mags(1:kk, g) = diag(Sg);
        end
    end
    sv_invariant.magnitudes = sort(sv_mags, 1, 'descend');

    % === Additional quantities ===
    sv_invariant.power_spec = abs(Ahat).^2;

    sv_invariant.trace_invs = zeros(4, 1);
    A_flat = reshape(A, [l*m, n]);
    % Frobenius norm squared: abs() is required so that complex input
    % gives sum |a|^2 rather than sum a^2.
    sv_invariant.trace_invs(1) = sum(abs(A_flat(:)).^2);
    for g = 1:n
        gram = A(:, :, g) * A(:, :, g)';
        sv_invariant.trace_invs(2) = sv_invariant.trace_invs(2) + trace(gram);
        sv_invariant.trace_invs(3) = sv_invariant.trace_invs(3) + trace(gram^2);
    end
    sv_invariant.trace_invs(4) = sum(svd(A_flat)); % Nuclear norm

    if realInput
        U = real(U);
        S = real(S);
        V = real(V);
    end
end
function feat = extractInvariantFeatures_stable(obj, A)
    %% extractInvariantFeatures_stable - Feature row vector built from starG_SVD_stable
    %%
    %% For A of size l x m x n, concatenates in this order:
    %%   1. the sorted singular value magnitudes, flattened;
    %%   2. sum, max, mean and std of the power spectrum;
    %%   3. the four trace quantities;
    %%   4. the eigenvalues of Ag*Ag' (sorted descending), averaged over
    %%      the n slices.
    %% The result is rounded to 12 decimal places.
    [l, ~, n] = size(A);

    % Function-call syntax resolves whether starG_SVD_stable is a class
    % method or a function in the same file; obj.starG_SVD_stable(A)
    % only works for the former.
    [~, ~, ~, sv_inv] = starG_SVD_stable(obj, A);

    % 2. Power spectrum statistics
    ps = sv_inv.power_spec;
    psStats = [sum(ps(:)), max(ps(:)), mean(ps(:)), std(ps(:))];

    % 4. Gram matrix eigenvalues, averaged over slices. Ag*Ag' is l x l,
    % so eig always returns exactly l values and no padding is needed.
    eig_sum = zeros(1, l);
    for g = 1:n
        Ag = A(:, :, g);
        eig_g = sort(real(eig(Ag * Ag')), 'descend');
        eig_sum = eig_sum + eig_g';
    end

    feat = [sv_inv.magnitudes(:)', psStats, sv_inv.trace_invs', eig_sum / n];

    % Round so that rounding noise below 1e-12 does not distinguish features.
    feat = round(feat, 12);
end