# tensor-group-sym / python / StarGAlgebra.py
"""
GPU-ACCELERATED StarGAlgebra CLASS
LH & SU 2026

Python conversion of MATLAB StarGAlgebra class for group-based
convolutional tensor algebra operations.
"""

import numpy as np
from numpy.fft import fft, ifft, fft2, ifft2
from scipy.linalg import dft, svd
from itertools import permutations
from typing import Optional, Tuple, List, Dict, Union, Any
import warnings

# Optional GPU backend: bind CuPy as `cp` when it can be imported; otherwise
# `cp` is None and CUPY_AVAILABLE is False, which disables the GPU code paths.
try:
    import cupy as cp
except ImportError:
    cp = None
CUPY_AVAILABLE = cp is not None


class StarGAlgebra:
    """
    StarGAlgebra: Group-based convolutional tensor algebra.

    A third-order tensor of shape ``(rows, cols, n)`` is read as a matrix
    whose entries are vectors indexed by the elements of a finite group of
    order ``n``.  The ★_G product multiplies these matrices while convolving
    the entries over the group:

        C[:, :, G[x, y]] += A[:, :, x] @ B[:, :, y]   for all x, y,

    where ``G`` is the group multiplication table.

    Supports various finite groups including cyclic, dihedral, symmetric,
    Klein-4, quaternion, direct products and arbitrary multiplication tables.
    For cyclic groups the product, convolution, SVD and truncation use FFTs;
    for other groups the products go through the convolution tensor.
    Optional GPU execution via CuPy is available through ``enable_gpu``.

    Parameters
    ----------
    group_type : str
        Type of group (case-insensitive): 'cyclic', 'dihedral', 'symmetric',
        'klein4', 'quaternion', 'product', or 'table'.
    group_param : optional
        Parameter for group initialization:

        - 'cyclic': order ``n`` of Z_n.
        - 'dihedral': ``n``, giving D_n of order ``2n``.
        - 'symmetric': ``n``, giving S_n of order ``n!``.
        - 'product': list of ``StarGAlgebra`` factors.
        - 'table': square multiplication table (array-like).
        - 'klein4', 'quaternion': ignored.

    Raises
    ------
    ValueError
        If ``group_type`` is unknown, or the multiplication table is
        malformed, has no identity, or has an element without an inverse.
    """

    def __init__(self, group_type: str, group_param: Any = None):
        # Multiplication table: G[a, b] is the index of the product a*b.
        self.G: Optional[np.ndarray] = None
        # Group order.
        self.n: int = 0
        # Transform matrices chosen by the _init_* method.  NOTE(review): for
        # the non-abelian groups these are plain DFT matrices, not true group
        # Fourier transforms; no product method in this class reads them.
        self.F: Optional[np.ndarray] = None
        self.Finv: Optional[np.ndarray] = None
        # Irreducible representation dimensions (not populated by this class).
        self.irrep_dims: List[int] = []
        # conv_tensor[a, b, c] == 1 exactly when a*b == c, 0 otherwise.
        self.conv_tensor: Optional[np.ndarray] = None
        # inv_table[a] is the index of a^{-1}.
        self.inv_table: Optional[np.ndarray] = None
        # Kept for API compatibility; not read by this class.
        self.use_parallel: bool = False
        self.use_gpu: bool = False
        self.is_abelian: bool = False
        self.is_cyclic: bool = False
        self.identity_idx: int = 0

        # CuPy copies of the tables above; None unless enable_gpu() succeeded.
        self.G_gpu = None
        self.conv_tensor_gpu = None
        self.inv_table_gpu = None
        self.F_gpu = None
        self.Finv_gpu = None

        initializers = {
            'cyclic': lambda: self._init_cyclic(group_param),
            'dihedral': lambda: self._init_dihedral(group_param),
            'symmetric': lambda: self._init_symmetric(group_param),
            'klein4': self._init_klein4,
            'quaternion': self._init_quaternion,
            'product': lambda: self._init_product(group_param),
            'table': lambda: self._init_from_table(group_param),
        }
        try:
            initialize = initializers[group_type.lower()]
        except KeyError:
            raise ValueError(f"Unknown group type: {group_type}") from None
        initialize()

        # Derived structures.  The identity is located before the inverse
        # table because inverses are defined relative to it.
        self._build_convolution_tensor()
        self._find_identity()
        self._build_inverse_table()
        self.is_abelian = self._check_abelian()

    # =========================================================================
    # GPU Methods
    # =========================================================================

    def _gpu_active(self) -> bool:
        """Return True when GPU mode is enabled and CuPy was importable."""
        # use_gpu is tested first so CUPY_AVAILABLE is only consulted once GPU
        # mode has actually been requested.
        return bool(self.use_gpu and CUPY_AVAILABLE)

    def enable_gpu(self) -> 'StarGAlgebra':
        """
        Enable GPU acceleration using CuPy.

        If CuPy is missing or the device upload fails, a warning is issued
        and the instance stays in CPU mode.

        Returns
        -------
        StarGAlgebra
            ``self``, to allow call chaining.
        """
        if not CUPY_AVAILABLE:
            warnings.warn("CuPy not available. Install with: pip install cupy")
            return self
        try:
            device = cp.cuda.Device()
            self.use_gpu = True
            self._init_gpu_arrays()
            print(f"GPU enabled: {device}")
        except Exception as e:
            # Roll back so a half-initialised state (use_gpu=True with missing
            # device arrays) cannot leak into later calls.
            self.disable_gpu()
            warnings.warn(f"Failed to enable GPU: {e}")
        return self

    def _init_gpu_arrays(self):
        """Upload the group tables and transform matrices to the GPU."""
        if not self._gpu_active():
            return
        self.G_gpu = cp.asarray(self.G.astype(np.int32))
        self.conv_tensor_gpu = cp.asarray(self.conv_tensor)
        self.inv_table_gpu = cp.asarray(self.inv_table.astype(np.int32))
        self.F_gpu = cp.asarray(self.F)
        self.Finv_gpu = cp.asarray(self.Finv)

    def disable_gpu(self) -> 'StarGAlgebra':
        """Disable GPU acceleration and drop all device arrays; returns ``self``."""
        self.use_gpu = False
        self.G_gpu = None
        self.conv_tensor_gpu = None
        self.inv_table_gpu = None
        self.F_gpu = None
        self.Finv_gpu = None
        return self

    def _to_gpu(self, arr: np.ndarray) -> Any:
        """Return ``arr`` as a CuPy array in GPU mode, unchanged otherwise."""
        return cp.asarray(arr) if self._gpu_active() else arr

    def _to_cpu(self, arr: Any) -> np.ndarray:
        """Return ``arr`` as a NumPy array, copying it off the GPU if needed."""
        if self._gpu_active() and isinstance(arr, cp.ndarray):
            return cp.asnumpy(arr)
        return np.asarray(arr)

    def _get_array_module(self, arr: Any):
        """Return the array module (numpy or cupy) that owns ``arr``."""
        return cp.get_array_module(arr) if self._gpu_active() else np

    # =========================================================================
    # Core Setup Methods
    # =========================================================================

    def _find_identity(self):
        """
        Locate the two-sided identity and store its index in ``identity_idx``.

        Raises
        ------
        ValueError
            If no element e satisfies ``e*a == a*e == a`` for every a.
        """
        elements = np.arange(self.n)
        for e in range(self.n):
            if (np.array_equal(self.G[e, :], elements)
                    and np.array_equal(self.G[:, e], elements)):
                self.identity_idx = e
                return
        raise ValueError("No identity element found")

    def _build_inverse_table(self):
        """
        Build ``inv_table`` so that ``G[a, inv_table[a]]`` is the identity.

        Relies on ``identity_idx`` (``__init__`` calls ``_find_identity``
        first).

        Raises
        ------
        ValueError
            If some element has no right inverse in the table.
        """
        hits = self.G == self.identity_idx
        has_inverse = hits.any(axis=1)
        if not has_inverse.all():
            missing = np.flatnonzero(~has_inverse).tolist()
            raise ValueError(f"No inverse found for element(s) {missing}")
        # argmax returns the first matching column, like a first-match scan.
        self.inv_table = hits.argmax(axis=1).astype(np.int32)

    def _build_convolution_tensor(self):
        """Build ``conv_tensor`` with a 1 at ``[a, b, G[a, b]]`` and 0 elsewhere."""
        ng = self.n
        rows, cols = np.indices((ng, ng))
        self.conv_tensor = np.zeros((ng, ng, ng))
        self.conv_tensor[rows, cols, self.G] = 1

    def _check_abelian(self) -> bool:
        """Return True when the multiplication table is symmetric (a*b == b*a)."""
        return np.array_equal(self.G, self.G.T)

    # =========================================================================
    # Convolution Methods
    # =========================================================================

    def convolve_direct(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """
        Direct group convolution using the convolution tensor.

        Computes ``c[z] = sum_{x, y} a[x] * b[y] * conv_tensor[x, y, z]``,
        i.e. ``c[g] = sum over x*y == g of a[x] * b[y]``.

        Parameters
        ----------
        a, b : ndarray
            Coefficient vectors with ``n`` entries (flattened before use).

        Returns
        -------
        c : ndarray of shape (n,)
        """
        a = a.flatten()
        b = b.flatten()

        if self._gpu_active():
            a = cp.asarray(a)
            b = cp.asarray(b)
            T = self.conv_tensor_gpu
        else:
            T = self.conv_tensor

        xp = self._get_array_module(a)
        c = xp.tensordot(xp.outer(a, b), T, axes=([0, 1], [0, 1]))
        return self._to_cpu(c)

    def convolve_inverse(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """
        Group convolution using the inverse formulation (CPU only).

        Computes ``c[z] = sum_g a[g] * b[g^{-1} * z]``, which equals the
        result of ``convolve_direct``.  Complex inputs keep their imaginary
        part; real inputs give a float result.

        Parameters
        ----------
        a, b : ndarray
            Coefficient vectors with ``n`` entries (flattened before use).

        Returns
        -------
        c : ndarray of shape (n,)
        """
        a = a.flatten()
        b = b.flatten()
        # shifted[g, z] = b[g^{-1} * z]; row g of G[inv_table] is g^{-1} * (.)
        shifted = b[self.G[self.inv_table, :]]
        out_dtype = np.result_type(a.dtype, b.dtype, np.float64)
        return np.asarray(a @ shifted, dtype=out_dtype)

    def convolve(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """
        Optimized group convolution.

        Uses FFT for cyclic groups, direct method otherwise.  For cyclic
        groups a real result is returned when both inputs are real-valued.
        """
        a = a.flatten()
        b = b.flatten()

        if not self.is_cyclic:
            return self.convolve_direct(a, b)

        inputs_real = bool(np.isreal(a).all() and np.isreal(b).all())
        if self._gpu_active():
            a_dev = cp.asarray(a)
            b_dev = cp.asarray(b)
            c = cp.asnumpy(cp.fft.ifft(cp.fft.fft(a_dev) * cp.fft.fft(b_dev)))
        else:
            c = ifft(fft(a) * fft(b))
        return np.real(c) if inputs_real else c

    # =========================================================================
    # Star-G Product Methods
    # =========================================================================

    def _lift_matrix(self, M: np.ndarray) -> np.ndarray:
        """Copy a 2-D matrix into every group slice; other ranks pass through."""
        if M.ndim != 2:
            return M
        return np.repeat(M[:, :, np.newaxis], self.n, axis=2)

    def star_g_direct(self, A: np.ndarray, B: np.ndarray) -> np.ndarray:
        """
        Reference ★_G product computed through the convolution tensor.

        ``C[i, j, z] = sum_{k, x, y} A[i, k, x] * B[k, j, y] * T[x, y, z]``
        with ``T = conv_tensor``.  Valid for any group.

        Parameters
        ----------
        A : ndarray of shape (l, m, n)
        B : ndarray of shape (m, p, n)

        Returns
        -------
        C : ndarray of shape (l, p, n)
            Real for real inputs; complex inputs keep their imaginary part.
        """
        if self._gpu_active():
            A = cp.asarray(A)
            B = cp.asarray(B)
            T = self.conv_tensor_gpu
            xp = cp
        else:
            T = self.conv_tensor
            xp = np

        # pair[i, j, x, y] = (A[:, :, x] @ B[:, :, y])[i, j]
        pair = xp.einsum('ikx,kjy->ijxy', A, B)
        C = xp.tensordot(pair, T, axes=([2, 3], [0, 1]))
        return self._to_cpu(C)

    def star_g_fourier(self, A: np.ndarray, B: np.ndarray) -> np.ndarray:
        """
        FFT-based ★_G product for cyclic groups.

        Transforms along mode 3, multiplies matching frequency slices, and
        transforms back.  The result is real when both inputs are real-valued.

        Parameters
        ----------
        A : ndarray of shape (l, m, n)
        B : ndarray of shape (m, p, n)

        Returns
        -------
        C : ndarray of shape (l, p, n)
        """
        inputs_real = bool(np.isreal(A).all() and np.isreal(B).all())

        if self._gpu_active():
            A = cp.asarray(A)
            B = cp.asarray(B)
            xp = cp
        else:
            xp = np

        Ahat = xp.fft.fft(A, axis=2)
        Bhat = xp.fft.fft(B, axis=2)

        # Batched facewise products: Chat[:, :, k] = Ahat[:, :, k] @ Bhat[:, :, k]
        Chat = xp.moveaxis(xp.moveaxis(Ahat, 2, 0) @ xp.moveaxis(Bhat, 2, 0), 0, 2)

        C = self._to_cpu(xp.fft.ifft(Chat, axis=2))
        return np.real(C) if inputs_real else C

    def star_g(self, A: np.ndarray, B: np.ndarray) -> np.ndarray:
        """
        ★_G product - Convolutional tensor product.

        Dispatches to ``star_g_fourier`` for cyclic groups and to
        ``star_g_direct`` otherwise.  A 2-D operand is copied into every
        group slice before the product.

        Parameters
        ----------
        A : ndarray of shape (l, m, n) or (l, m)
        B : ndarray of shape (m, p, n) or (m, p)

        Returns
        -------
        C : ndarray of shape (l, p, n)

        Raises
        ------
        AssertionError
            If inner dimensions differ or mode 3 does not equal the group order.
        """
        A = self._lift_matrix(A)
        B = self._lift_matrix(B)

        _, m, n_A = A.shape
        m2, _, n_B = B.shape

        assert m == m2, "Inner dimensions must match"
        assert n_A == self.n and n_B == self.n, "Mode-3 must equal group order"

        if self.is_cyclic:
            return self.star_g_fourier(A, B)
        return self.star_g_direct(A, B)

    def star_g_batch(self, A: np.ndarray, B: np.ndarray) -> np.ndarray:
        """
        Batch ★_G product for 4D tensors.

        Parameters
        ----------
        A : ndarray of shape (l, m, n, batch)
        B : ndarray of shape (m, p, n, batch)

        Returns
        -------
        C : ndarray of shape (l, p, n, batch)
            ``C[..., t] == star_g(A[..., t], B[..., t])``; complex inputs keep
            their imaginary part.
        """
        l, m, ng, batch = A.shape
        _, p, _, _ = B.shape

        if self._gpu_active() and self.is_cyclic:
            inputs_real = bool(np.isreal(A).all() and np.isreal(B).all())
            Ahat = cp.fft.fft(cp.asarray(A), axis=2)
            Bhat = cp.fft.fft(cp.asarray(B), axis=2)
            # (l, m, n, batch) -> (n, batch, l, m) so matmul batches over both
            # frequency and batch, then back to (l, p, n, batch).
            Chat = cp.transpose(
                cp.transpose(Ahat, (2, 3, 0, 1)) @ cp.transpose(Bhat, (2, 3, 0, 1)),
                (2, 3, 0, 1),
            )
            C = cp.fft.ifft(Chat, axis=2)
            if inputs_real:
                C = C.real
            return cp.asnumpy(C)

        out_dtype = np.result_type(A.dtype, B.dtype, np.float64)
        C = np.zeros((l, p, ng, batch), dtype=out_dtype)
        for b_idx in range(batch):
            C[:, :, :, b_idx] = self.star_g(A[:, :, :, b_idx], B[:, :, :, b_idx])
        return C

    def star_g_old(self, A: np.ndarray, B: np.ndarray) -> np.ndarray:
        """
        ★_G product evaluated straight from the multiplication table.

        For every pair of group elements (x, y) the facewise product
        ``A[:, :, x] @ B[:, :, y]`` is accumulated into slice ``G[x, y]``.
        This costs n**2 matrix products and is valid for any group.  A 2-D
        operand is copied into every group slice first.

        Parameters
        ----------
        A : ndarray of shape (l, m, n) or (l, m)
        B : ndarray of shape (m, p, n) or (m, p)

        Returns
        -------
        C : ndarray of shape (l, p, n)
            Real when both inputs are real-valued.
        """
        A = self._lift_matrix(A)
        B = self._lift_matrix(B)

        l, m, _ = A.shape
        m2, p, _ = B.shape
        n = self.n

        assert m == m2, "Inner dimensions must match"

        # No transform is used: convolving transformed slices would correspond
        # to a pointwise (not convolutional) product in the group domain.
        C = np.zeros((l, p, n), dtype=np.result_type(A.dtype, B.dtype, np.float64))
        for x in range(n):
            A_x = A[:, :, x]
            for y in range(n):
                C[:, :, self.G[x, y]] += A_x @ B[:, :, y]

        if np.isreal(A).all() and np.isreal(B).all():
            C = np.real(C)

        return C

    # =========================================================================
    # Conjugate Transpose
    # =========================================================================

    def conjugate_transpose(self, A: np.ndarray) -> np.ndarray:
        """
        Compute the ★_G conjugate transpose (reference loop version).

        ``Ah[i, j, g] = conj(A[j, i, g^{-1}])``.

        Parameters
        ----------
        A : ndarray of shape (m, p, n)

        Returns
        -------
        Ah : ndarray of shape (p, m, n), same dtype as ``A``
        """
        m, p, n = A.shape
        Ah = np.zeros((p, m, n), dtype=A.dtype)

        for i in range(p):
            for j in range(m):
                for g in range(n):
                    g_inv = self.inv_table[g]
                    Ah[i, j, g] = np.conj(A[j, i, g_inv])

        return Ah

    def conjugate_transpose_fast(self, A: np.ndarray) -> np.ndarray:
        """Vectorized ★_G conjugate transpose; same result as ``conjugate_transpose``."""
        Ah = np.conj(A).transpose(1, 0, 2)
        return Ah[:, :, self.inv_table]

    # =========================================================================
    # SVD Methods
    # =========================================================================

    def star_g_svd(self, A: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute ★_G-SVD decomposition ``A = U ★ S ★ V^H``.

        ``V^H`` is obtained with ``conjugate_transpose(V)``.  For cyclic
        groups each Fourier slice ``Ahat[:, :, k]`` gets a thin SVD and the
        factors are transformed back.  When ``A`` is real, only slices
        ``0 .. n//2`` are factored and the rest are filled by conjugate
        symmetry, so the returned factors are real.

        For non-cyclic groups each raw slice ``A[:, :, g]`` is factored
        independently.  NOTE(review): this is a facewise SVD, not a ★_G
        factorization; ``U ★ S ★ V^H`` does not reproduce ``A`` in general.

        Parameters
        ----------
        A : ndarray of shape (l, m, n)

        Returns
        -------
        U : ndarray of shape (l, min(l,m), n)
        S : ndarray of shape (min(l,m), min(l,m), n)
        V : ndarray of shape (m, min(l,m), n)
        """
        l, m, n = A.shape
        minlm = min(l, m)
        A_is_real = bool(np.isreal(A).all())

        if self.is_cyclic:
            if self._gpu_active():
                Ahat = cp.asnumpy(cp.fft.fft(cp.asarray(A), axis=2))
            else:
                Ahat = fft(A, axis=2)

            Uhat = np.zeros((l, minlm, n), dtype=complex)
            Shat = np.zeros((minlm, minlm, n), dtype=complex)
            Vhat = np.zeros((m, minlm, n), dtype=complex)

            # For real A, Ahat[:, :, n-k] == conj(Ahat[:, :, k]).  Independent
            # SVDs of mirrored slices may choose unrelated phases, which would
            # make ifft(Uhat) complex; mirroring the factors instead keeps them
            # conjugate-symmetric, hence real after the inverse FFT.
            n_solve = n // 2 + 1 if A_is_real else n
            for k in range(n_solve):
                slice_k = Ahat[:, :, k]
                self_conjugate = A_is_real and (2 * k) % n == 0
                if self_conjugate:
                    # Slices 0 (and n/2 for even n) of a real signal are real.
                    slice_k = slice_k.real
                Uk, Sk, Vhk = svd(slice_k, full_matrices=False)
                Uhat[:, :, k] = Uk
                Shat[:, :, k] = np.diag(Sk)
                Vhat[:, :, k] = Vhk.conj().T  # V = Vh^H, not Vh^T
                if A_is_real and not self_conjugate:
                    Uhat[:, :, n - k] = Uk.conj()
                    Shat[:, :, n - k] = Shat[:, :, k]
                    Vhat[:, :, n - k] = Vhk.T

            U = ifft(Uhat, axis=2)
            S = ifft(Shat, axis=2)
            V = ifft(Vhat, axis=2)
        else:
            U = np.zeros((l, minlm, n), dtype=complex)
            S = np.zeros((minlm, minlm, n), dtype=complex)
            V = np.zeros((m, minlm, n), dtype=complex)

            for g in range(n):
                Ug, Sg, Vhg = svd(A[:, :, g], full_matrices=False)
                U[:, :, g] = Ug
                S[:, :, g] = np.diag(Sg)
                V[:, :, g] = Vhg.conj().T

        if A_is_real:
            U = np.real(U)
            S = np.real(S)
            V = np.real(V)

        return U, S, V

    @staticmethod
    def _truncated_slice(M: np.ndarray, k: int) -> np.ndarray:
        """Return the rank-``k`` thin-SVD approximation of matrix ``M``."""
        U, s, Vh = svd(M, full_matrices=False)
        r = min(k, len(s))
        # U[:, :r] * s[:r] scales columns, i.e. U @ diag(s) without forming diag.
        return (U[:, :r] * s[:r]) @ Vh[:r, :]

    def truncate(self, A: np.ndarray, k: int) -> np.ndarray:
        """
        Truncated ★_G-SVD reconstruction.

        Each Fourier slice (cyclic groups) or raw slice (other groups) is
        replaced by its rank-``k`` SVD approximation.

        Parameters
        ----------
        A : ndarray of shape (l, m, n)
        k : int
            Target rank; values above ``min(l, m)`` are clamped.

        Returns
        -------
        Ak : ndarray of shape (l, m, n)
            Real when ``A`` is real-valued.

        Raises
        ------
        ValueError
            If ``k`` is negative.
        """
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")

        l, m, n = A.shape
        k = min(k, l, m)
        A_is_real = bool(np.isreal(A).all())

        if self.is_cyclic:
            if self._gpu_active():
                Ahat = cp.asnumpy(cp.fft.fft(cp.asarray(A), axis=2))
            else:
                Ahat = fft(A, axis=2)

            Akhat = np.zeros((l, m, n), dtype=complex)
            for i in range(n):
                Akhat[:, :, i] = self._truncated_slice(Ahat[:, :, i], k)
            Ak = ifft(Akhat, axis=2)
        else:
            Ak = np.zeros((l, m, n), dtype=complex)
            for g in range(n):
                Ak[:, :, g] = self._truncated_slice(A[:, :, g], k)

        if A_is_real:
            Ak = np.real(Ak)

        return Ak

    # =========================================================================
    # Group Initialization Methods
    # =========================================================================

    def _init_cyclic(self, n: int):
        """Initialize cyclic group Z_n with table ``G[a, b] = (a + b) % n``."""
        self.n = n
        self.is_cyclic = True
        I, J = np.meshgrid(np.arange(n), np.arange(n))
        self.G = (I + J) % n
        self.identity_idx = 0
        self.F = dft(n)
        self.Finv = np.conj(self.F) / n

    def _init_klein4(self):
        """Initialize Klein four-group (every element is its own inverse)."""
        self.n = 4
        self.is_cyclic = False
        self.G = np.array([
            [0, 1, 2, 3],
            [1, 0, 3, 2],
            [2, 3, 0, 1],
            [3, 2, 1, 0]
        ])
        self.identity_idx = 0
        # Character table of the Klein group (a 4x4 Hadamard matrix).
        self.F = np.array([
            [1, 1, 1, 1],
            [1, 1, -1, -1],
            [1, -1, 1, -1],
            [1, -1, -1, 1]
        ], dtype=float)
        self.Finv = self.F / 4

    def _init_dihedral(self, n: int):
        """
        Initialize dihedral group D_n of order 2n.

        Indices ``0..n-1`` behave as rotations r^i and ``n..2n-1`` as
        reflections; the table is consistent with reading index ``n + k``
        as ``r^k s`` (so ``s r^j = r^{-j} s``).
        """
        self.n = 2 * n
        self.is_cyclic = False
        self.G = np.zeros((2*n, 2*n), dtype=int)

        for i in range(2*n):
            for j in range(2*n):
                if i < n and j < n:
                    self.G[i, j] = (i + j) % n
                elif i < n and j >= n:
                    self.G[i, j] = ((j - n) + i) % n + n
                elif i >= n and j < n:
                    self.G[i, j] = ((i - n) - j) % n + n
                else:
                    self.G[i, j] = ((i - n) - (j - n)) % n

        self.identity_idx = 0
        # NOTE(review): placeholder DFT; not a Fourier transform for D_n.
        self.F = dft(2*n)
        self.Finv = np.conj(self.F) / (2*n)

    def _init_symmetric(self, n: int):
        """
        Initialize symmetric group S_n of order n!.

        Elements are the permutations of ``range(n)`` with the identity at
        index 0; ``G[i, j]`` is the composition ``perm_i ∘ perm_j`` (apply
        ``perm_j`` first).
        """
        self.is_cyclic = False
        perms_list = list(permutations(range(n)))
        self.n = len(perms_list)

        # Move identity to front
        identity_perm = tuple(range(n))
        identity_idx = perms_list.index(identity_perm)
        if identity_idx != 0:
            perms_list[0], perms_list[identity_idx] = perms_list[identity_idx], perms_list[0]

        perm_to_idx = {perm: i for i, perm in enumerate(perms_list)}

        self.G = np.zeros((self.n, self.n), dtype=int)
        for i in range(self.n):
            for j in range(self.n):
                composed = tuple(perms_list[i][k] for k in perms_list[j])
                self.G[i, j] = perm_to_idx[composed]

        self.identity_idx = 0
        # NOTE(review): placeholder DFT; not a Fourier transform for S_n.
        self.F = dft(self.n)
        self.Finv = np.conj(self.F) / self.n

    def _init_quaternion(self):
        """Initialize quaternion group Q_8."""
        self.n = 8
        self.is_cyclic = False
        # Cayley table for Q_8 = {1, -1, i, -i, j, -j, k, -k}
        # Using 0-indexed: {0, 1, 2, 3, 4, 5, 6, 7}
        self.G = np.array([
            [0, 1, 2, 3, 4, 5, 6, 7],
            [1, 0, 3, 2, 5, 4, 7, 6],
            [2, 3, 1, 0, 6, 7, 5, 4],
            [3, 2, 0, 1, 7, 6, 4, 5],
            [4, 5, 7, 6, 1, 0, 2, 3],
            [5, 4, 6, 7, 0, 1, 3, 2],
            [6, 7, 4, 5, 3, 2, 1, 0],
            [7, 6, 5, 4, 2, 3, 0, 1]
        ])
        self.identity_idx = 0
        # NOTE(review): placeholder DFT; not a Fourier transform for Q_8.
        self.F = dft(8)
        self.Finv = np.conj(self.F) / 8

    def _init_product(self, groups: List['StarGAlgebra']):
        """
        Initialize the direct product of the given groups.

        Element indices are row-major multi-indices over the factors
        (``_linear_to_multi_index``), multiplied componentwise.
        """
        k = len(groups)
        orders = [g.n for g in groups]
        self.n = int(np.prod(orders))
        self.is_cyclic = False

        self.G = np.zeros((self.n, self.n), dtype=int)

        for a in range(self.n):
            a_idx = self._linear_to_multi_index(a, orders)
            for b in range(self.n):
                b_idx = self._linear_to_multi_index(b, orders)
                c_idx = [groups[i].G[a_idx[i], b_idx[i]] for i in range(k)]
                self.G[a, b] = self._multi_to_linear_index(c_idx, orders)

        # Provisional; __init__ recomputes it with _find_identity.
        self.identity_idx = 0

        # Kronecker product of Fourier matrices
        self.F = groups[0].F
        self.Finv = groups[0].Finv
        for i in range(1, k):
            self.F = np.kron(self.F, groups[i].F)
            self.Finv = np.kron(self.Finv, groups[i].Finv)

    def _linear_to_multi_index(self, lin: int, orders: List[int]) -> List[int]:
        """Convert a linear index to a row-major multi-index over ``orders``."""
        k = len(orders)
        idx = [0] * k
        for i in range(k-1, -1, -1):
            idx[i] = lin % orders[i]
            lin = lin // orders[i]
        return idx

    def _multi_to_linear_index(self, idx: List[int], orders: List[int]) -> int:
        """Convert a row-major multi-index over ``orders`` to a linear index."""
        lin = 0
        for i in range(len(orders)):
            lin = lin * orders[i] + idx[i]
        return lin

    def _init_from_table(self, mult_table: np.ndarray):
        """
        Initialize from an explicit multiplication table.

        Parameters
        ----------
        mult_table : array-like of shape (n, n)
            ``mult_table[a, b]`` is the index of ``a*b``; entries must lie in
            ``range(n)``.

        Raises
        ------
        ValueError
            If the table is not square or has out-of-range entries.
        """
        table = np.asarray(mult_table)
        if table.ndim != 2 or table.shape[0] != table.shape[1]:
            raise ValueError(f"Multiplication table must be square, got shape {table.shape}")
        n = table.shape[0]
        # Negative entries would otherwise wrap around silently when the
        # table is used for fancy indexing.
        if table.size and (table.min() < 0 or table.max() >= n):
            raise ValueError(f"Multiplication table entries must lie in [0, {n - 1}]")
        self.n = n
        self.G = table.astype(int)
        self.is_cyclic = False
        # NOTE(review): placeholder DFT; not a Fourier transform for a general group.
        self.F = dft(self.n)
        self.Finv = np.conj(self.F) / self.n

    # =========================================================================
    # Utility Methods
    # =========================================================================

    def identity_tensor(self, m: int) -> np.ndarray:
        """Create the m x m x n ★_G identity (eye(m) in the identity slot)."""
        I = np.zeros((m, m, self.n))
        I[:, :, self.identity_idx] = np.eye(m)
        return I

    def tensor_norm(self, A: np.ndarray) -> float:
        """Compute Frobenius norm of tensor."""
        return np.sqrt(np.sum(np.abs(A) ** 2))

    # =========================================================================
    # Benchmark
    # =========================================================================

    def benchmark(self, sizes: Optional[List[int]] = None):
        """
        Print timings of ``star_g_direct`` versus the dispatching ``star_g``.

        Parameters
        ----------
        sizes : list of int, optional
            Sizes ``sz`` of the random ``(sz, sz, n)`` operands.  Defaults
            to ``[10, 20, 50, 100]``.  The direct method is only timed for
            ``sz <= 30``.
        """
        import time

        if sizes is None:
            sizes = [10, 20, 50, 100]

        print("\n=== Performance Benchmark ===")
        print(f"Group: n={self.n}, Cyclic={self.is_cyclic}, GPU={self.use_gpu}")
        print(f"{'Size':<10} {'Direct (s)':<15} {'Optimized (s)':<15} {'Speedup':<10}")
        print("-" * 55)

        for sz in sizes:
            A = np.random.randn(sz, sz, self.n)
            B = np.random.randn(sz, sz, self.n)

            # Warm up the method that is timed below.
            _ = self.star_g(A, B)

            if sz <= 30:
                start = time.perf_counter()
                for _ in range(3):
                    self.star_g_direct(A, B)
                t_direct = (time.perf_counter() - start) / 3
            else:
                t_direct = float('nan')

            start = time.perf_counter()
            for _ in range(10):
                self.star_g(A, B)
            t_opt = (time.perf_counter() - start) / 10

            if not np.isnan(t_direct):
                print(f"{sz:<10} {t_direct:<15.4f} {t_opt:<15.4f} {t_direct/t_opt:<10.1f}x")
            else:
                print(f"{sz:<10} {'N/A':<15} {t_opt:<15.4f} {'-':<10}")

    # =========================================================================
    # Verification Suite
    # =========================================================================

    def run_all_tests(self):
        """Run the complete verification suite, printing PASS/FAIL per test."""
        print("=" * 40)
        print("StarGAlgebra Verification Suite")
        print(f"Group order: {self.n}, Cyclic: {self.is_cyclic}, Abelian: {self.is_abelian}")
        print(f"Identity at index: {self.identity_idx}, GPU: {self.use_gpu}")
        print("=" * 40 + "\n")

        self._test_group_axioms()
        self._test_convolution_tensor()
        self._test_convolution_methods()
        self._test_star_g_methods()
        self._test_conjugate_transpose()
        self._test_identity()
        self._test_associativity()
        self._test_svd()

        print("\n" + "=" * 40)
        print("All tests completed.")
        print("=" * 40)

    def _test_group_axioms(self):
        """Check identity, inverses, and (sampled) associativity of the table."""
        print("Test 1: Group Axioms")
        passed = True
        e = self.identity_idx

        for a in range(self.n):
            if self.G[e, a] != a or self.G[a, e] != a:
                print(f"  FAIL: Identity for {a}")
                passed = False

        for a in range(self.n):
            a_inv = self.inv_table[a]
            if self.G[a, a_inv] != e or self.G[a_inv, a] != e:
                print(f"  FAIL: Inverse for {a}")
                passed = False

        # Random sample of triples rather than all n**3.
        for _ in range(min(self.n**3, 500)):
            a, b, c = np.random.randint(0, self.n, 3)
            if self.G[self.G[a, b], c] != self.G[a, self.G[b, c]]:
                print("  FAIL: Associativity")
                passed = False
                break

        if passed:
            print("  PASS")

    def _test_convolution_tensor(self):
        """Check conv_tensor is 1 exactly at [a, b, G[a, b]] and 0 elsewhere."""
        print("\nTest 2: Convolution Tensor")
        rows, cols = np.indices((self.n, self.n))
        ones_where_expected = np.all(self.conv_tensor[rows, cols, self.G] == 1)
        # Exactly n**2 non-zeros means none sit outside the expected positions.
        no_extra_entries = np.count_nonzero(self.conv_tensor) == self.n ** 2
        passed = bool(ones_where_expected and no_extra_entries)
        print("  PASS" if passed else "  FAIL")

    def _test_convolution_methods(self):
        """Compare the three 1D convolution implementations."""
        print("\nTest 3: 1D Convolution")

        a = np.random.randn(self.n)
        b = np.random.randn(self.n)

        c1 = self.convolve_direct(a, b)
        c2 = self.convolve_inverse(a, b)
        c3 = self.convolve(a, b)

        err1 = np.linalg.norm(c1 - c2) / np.linalg.norm(c1)
        err2 = np.linalg.norm(c1 - c3) / np.linalg.norm(c1)

        print(f"  Direct vs Inverse: {err1:.2e}")
        print(f"  Direct vs Optimized: {err2:.2e}")

        tol = 1e-6
        print("  PASS" if err1 < tol and err2 < tol else "  FAIL")

    def _test_star_g_methods(self):
        """Compare star_g_direct against star_g_old and the dispatching star_g."""
        print("\nTest 4: StarG Product")

        import time

        A = np.random.randn(3, 4, self.n)
        B = np.random.randn(4, 2, self.n)

        start = time.perf_counter()
        C_direct = self.star_g_direct(A, B)
        t1 = time.perf_counter() - start

        start = time.perf_counter()
        C_main = self.star_g_old(A, B)
        t2 = time.perf_counter() - start

        C_fast = self.star_g(A, B)

        ref_norm = np.linalg.norm(C_direct)
        err = np.linalg.norm(C_direct - C_main) / ref_norm
        err_fast = np.linalg.norm(C_direct - C_fast) / ref_norm
        print(f"  Direct vs Main: error={err:.2e} ({t1:.4f}s vs {t2:.4f}s)")
        print(f"  Direct vs star_g: error={err_fast:.2e}")

        tol = 1e-6
        print("  PASS" if err < tol and err_fast < tol else "  FAIL")

    def _test_conjugate_transpose(self):
        """Compare loop and vectorized conjugate transposes on complex data."""
        print("\nTest 5: Conjugate Transpose")

        A = np.random.randn(3, 4, self.n) + 1j * np.random.randn(3, 4, self.n)

        Ah1 = self.conjugate_transpose(A)
        Ah2 = self.conjugate_transpose_fast(A)

        err = np.linalg.norm(Ah1 - Ah2) / np.linalg.norm(Ah1)

        print(f"  PASS (error={err:.2e})" if err < 1e-10 else f"  FAIL (error={err:.2e})")

    def _test_identity(self):
        """Check I ★ A == A ★ I == A for the identity tensor."""
        print("\nTest 6: Identity")

        m = 4
        I = self.identity_tensor(m)
        A = np.random.randn(m, m, self.n)

        IA = self.star_g(I, A)
        AI = self.star_g(A, I)

        err1 = np.linalg.norm(IA - A) / np.linalg.norm(A)
        err2 = np.linalg.norm(AI - A) / np.linalg.norm(A)

        print(f"  ||I*A - A||/||A|| = {err1:.2e}")
        print(f"  ||A*I - A||/||A|| = {err2:.2e}")

        tol = 1e-10
        print("  PASS" if err1 < tol and err2 < tol else "  FAIL")

    def _test_associativity(self):
        """Check (A ★ B) ★ C == A ★ (B ★ C)."""
        print("\nTest 7: Associativity")

        A = np.random.randn(2, 3, self.n)
        B = np.random.randn(3, 2, self.n)
        C = np.random.randn(2, 2, self.n)

        ABC_left = self.star_g(self.star_g(A, B), C)
        ABC_right = self.star_g(A, self.star_g(B, C))

        err = np.linalg.norm(ABC_left - ABC_right) / np.linalg.norm(ABC_left)

        print(f"  PASS (error={err:.2e})" if err < 1e-10 else f"  FAIL (error={err:.2e})")

    def _test_svd(self):
        """Check U ★ S ★ V^H reproduces A (cyclic groups only)."""
        print("\nTest 8: SVD")

        if not self.is_cyclic:
            print("  SKIP (non-cyclic)")
            return

        A = np.random.randn(4, 3, self.n)
        U, S, V = self.star_g_svd(A)

        Vh = self.conjugate_transpose(V)
        A_rec = self.star_g(self.star_g(U, S), Vh)

        err = np.linalg.norm(A - A_rec) / np.linalg.norm(A)

        print(f"  PASS (error={err:.2e})" if err < 1e-10 else f"  FAIL (error={err:.2e})")


# =============================================================================
# Additional Stable SVD Functions
# =============================================================================

def star_g_svd_stable(algebra: "StarGAlgebra", A: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """
    Numerically stable ★_G-SVD together with invariant summaries of ``A``.

    Parameters
    ----------
    algebra : StarGAlgebra
        The group algebra instance. Only ``algebra.is_cyclic`` is consulted.
    A : ndarray of shape (l, m, n)
        Input tensor; the third axis is indexed by group element.

    Returns
    -------
    U : ndarray of shape (l, min(l, m), n)
    S : ndarray of shape (min(l, m), min(l, m), n)
    V : ndarray of shape (m, min(l, m), n)
        Factors. In the cyclic case every Fourier slice satisfies
        ``fft(A)[:, :, k] == fft(U)[:, :, k] @ fft(S)[:, :, k] @ fft(V)[:, :, k]^H``.
        In the non-cyclic case every group slice satisfies
        ``A[:, :, g] == U[:, :, g] @ S[:, :, g] @ V[:, :, g]^H``.
        When ``A`` is real-valued the factors are returned as real arrays.
    sv_invariant : dict
        - ``'magnitudes'``: singular values per slice, shape
          ``(min(l, m), n // 2 + 1)`` (one-sided spectrum) for cyclic groups
          and ``(min(l, m), n)`` otherwise; each column is in descending order.
        - ``'power_spec'``: ``|fft(A, axis=2)| ** 2``, shape ``(l, m, n)``.
        - ``'trace_invs'``: length-4 array
          ``[sum |A|^2, sum_g tr(A_g A_g^H), sum_g tr((A_g A_g^H)^2),
          nuclear norm of A.reshape(l * m, n)]``.
    """
    l, m, n = A.shape
    minlm = min(l, m)
    real_input = bool(np.isreal(A).all())

    sv_invariant = {}

    if algebra.is_cyclic:
        # Fourier domain computation
        Ahat = fft(A, axis=2)
        n_freq = n // 2 + 1  # one-sided spectrum

        Uhat = np.zeros((l, minlm, n), dtype=complex)
        Shat = np.zeros((minlm, minlm, n), dtype=complex)
        Vhat = np.zeros((m, minlm, n), dtype=complex)
        sv_mags = np.zeros((minlm, n_freq))

        # Real A has a conjugate-symmetric spectrum:
        # Ahat[:, :, n - k] == conj(Ahat[:, :, k]).
        # Factor only k = 0 .. n // 2 and mirror the conjugated factors into
        # the other half. Independently computed SVDs of mirrored slices can
        # differ by arbitrary unit phases. That would leave imaginary residue
        # in the inverse FFTs, which the final real cast would silently
        # discard, breaking the factorization.
        last_k = n_freq if real_input else n
        for k in range(last_k):
            slice_k = Ahat[:, :, k]
            self_conjugate = real_input and (k == 0 or 2 * k == n)
            if self_conjugate:
                # These slices are exactly real for real A; factor them over
                # the reals so their factors are real as well.
                slice_k = slice_k.real
            Uk, Sk, Vhk = svd(slice_k, full_matrices=False)
            Vk = Vhk.conj().T  # V = (V^H)^H; a plain .T would yield conj(V)

            Uhat[:, :, k] = Uk
            Shat[:, :, k] = np.diag(Sk)
            Vhat[:, :, k] = Vk
            if k < n_freq:
                sv_mags[:, k] = Sk

            if real_input and not self_conjugate:
                Uhat[:, :, n - k] = Uk.conj()
                Shat[:, :, n - k] = np.diag(Sk)
                Vhat[:, :, n - k] = Vk.conj()

        U = ifft(Uhat, axis=2)
        S = ifft(Shat, axis=2)
        V = ifft(Vhat, axis=2)

    else:
        # Direct computation for non-cyclic groups: each group slice is
        # factored on its own.
        # NOTE(review): these per-slice factors are not, in general, a
        # factorization under the ★_G product -- confirm this is intended.
        U = np.zeros((l, minlm, n), dtype=complex)
        S = np.zeros((minlm, minlm, n), dtype=complex)
        V = np.zeros((m, minlm, n), dtype=complex)
        sv_mags = np.zeros((minlm, n))

        for g in range(n):
            Ug, Sg, Vhg = svd(A[:, :, g], full_matrices=False)
            U[:, :, g] = Ug
            S[:, :, g] = np.diag(Sg)
            V[:, :, g] = Vhg.conj().T  # V = (V^H)^H, correct for complex A too
            sv_mags[:, g] = Sg

    # Sort each column descending (SVD already returns them that way).
    sv_invariant['magnitudes'] = np.sort(sv_mags, axis=0)[::-1]

    # Additional invariants
    sv_invariant['power_spec'] = np.abs(fft(A, axis=2)) ** 2

    # Trace invariants; |.|^2 and the conjugate transpose keep these real and
    # well-defined for complex input (identical to x**2 / .T for real input).
    trace_invs = np.zeros(4)
    trace_invs[0] = np.sum(np.abs(A) ** 2)  # squared Frobenius norm
    for g in range(n):
        Ag = A[:, :, g]
        gram_g = Ag @ Ag.conj().T
        trace_invs[1] += np.real(np.trace(gram_g))
        trace_invs[2] += np.real(np.trace(gram_g @ gram_g))
    trace_invs[3] = np.sum(svd(A.reshape(l * m, n), compute_uv=False))  # nuclear norm
    sv_invariant['trace_invs'] = trace_invs

    if real_input:
        U = np.real(U)
        S = np.real(S)
        V = np.real(V)

    return U, S, V, sv_invariant


def extract_invariant_features_stable(algebra: "StarGAlgebra", A: np.ndarray) -> np.ndarray:
    """
    Extract a feature vector of invariant summaries from a tensor.

    Parameters
    ----------
    algebra : StarGAlgebra
        The group algebra instance (forwarded to ``star_g_svd_stable``).
    A : ndarray of shape (l, m, n)
        Input tensor.

    Returns
    -------
    feat : ndarray
        1-D float array, rounded to 12 decimals, that concatenates in order:

        1. the flattened ``'magnitudes'`` array from ``star_g_svd_stable``;
        2. sum, max, mean and standard deviation of the power spectrum;
        3. the four trace invariants;
        4. the eigenvalues of each slice's Gram matrix ``A_g A_g^H``, sorted
           descending and averaged over the ``n`` slices (length ``l``).
    """
    l, _, n = A.shape

    # Get stable SVD with invariants
    _, _, _, sv_inv = star_g_svd_stable(algebra, A)

    # 2. Power spectrum statistics
    ps = sv_inv['power_spec']
    ps_stats = [np.sum(ps), np.max(ps), np.mean(ps), np.std(ps)]

    # 4. Gram matrices A_g A_g^H for all slices, stacked as (n, l, l).
    # They are Hermitian positive semi-definite, so eigvalsh returns real
    # eigenvalues directly. The general eigvals solver can produce spurious
    # imaginary parts and is less accurate here.
    grams = np.einsum('ijg,kjg->gik', A, A.conj())
    gram_eigs = np.linalg.eigvalsh(grams)[:, ::-1]  # descending per slice
    eig_mean = gram_eigs.mean(axis=0)               # shape (l,)

    feat = np.concatenate([
        sv_inv['magnitudes'].ravel(),  # 1. singular value magnitudes
        ps_stats,                      # 2. power spectrum statistics
        sv_inv['trace_invs'],          # 3. trace invariants
        eig_mean,                      # 4. averaged Gram eigenvalues
    ])

    # Round away last-bit floating-point noise so that features of
    # equivalent inputs compare equal in the common case.
    return np.round(feat, 12)


# =============================================================================
# Example Usage
# =============================================================================

if __name__ == "__main__":
    banner = "\n" + "=" * 60 + "\n"

    # Run the built-in self-tests on a few representative groups.
    # Each entry is (heading, positional arguments for StarGAlgebra).
    group_demos = (
        ("Testing Cyclic Group Z_5:", ("cyclic", 5)),
        ("Testing Klein-4 Group:", ("klein4",)),
        ("Testing Dihedral Group D_3:", ("dihedral", 3)),
    )
    built = []
    for heading, ctor_args in group_demos:
        print(heading)
        algebra = StarGAlgebra(*ctor_args)
        algebra.run_all_tests()
        built.append(algebra)
        print(banner)

    # The cyclic Z_5 algebra is reused for the stable-SVD demo and benchmark.
    z5 = built[0]

    print("Testing Stable SVD:")
    sample = np.random.randn(4, 3, 5)
    *_, invariants = star_g_svd_stable(z5, sample)
    print(f"  Invariant magnitudes shape: {invariants['magnitudes'].shape}")
    print(f"  Trace invariants: {invariants['trace_invs']}")

    features = extract_invariant_features_stable(z5, sample)
    print(f"  Feature vector length: {len(features)}")

    print(banner)
    print("Running benchmark:")
    z5.benchmark([5, 10, 20])