"""GroupAlgebra: PyTorch port of the ★_G algebra core. Caches the multiplication table, generalized Fourier matrix F_G, irrep dimensions, and the inverse table. Supports cyclic, dihedral, octahedral, Klein-4, quaternion, and direct product groups. The MATLAB reference implementation is core/StarGAlgebra.m; this version is a faithful torch port designed for batched GPU execution. """ from __future__ import annotations from typing import List, Optional, Tuple import numpy as np import torch class GroupAlgebra: def __init__( self, group_type: str, param=None, *, device: Optional[torch.device] = None, dtype: torch.dtype = torch.complex64, ): self.group_type = group_type.lower() self.device = device or torch.device("cpu") self.dtype = dtype if self.group_type == "cyclic": self._init_cyclic(int(param)) elif self.group_type == "dihedral": self._init_dihedral(int(param)) elif self.group_type == "klein4": self._init_klein4() elif self.group_type == "octahedral": self._init_octahedral() elif self.group_type == "product": self._init_product(param) else: raise ValueError(f"Unknown group type: {group_type}") self._build_conv_tensor() self._find_identity() self._build_inverse_table() self.is_abelian = bool(np.array_equal(self.G_table, self.G_table.T)) # ------------------------------------------------------------------ # Group constructors # ------------------------------------------------------------------ def _init_cyclic(self, n: int): self.n = n self.G_table = (np.arange(n)[:, None] + np.arange(n)[None, :]) % n # F_G = DFT_n; abelian, irrep dims all 1 F = np.fft.fft(np.eye(n)) / np.sqrt(n) self.F = torch.tensor(F, dtype=self.dtype, device=self.device) self.Finv = torch.linalg.inv(self.F) self.irrep_dims = [1] * n self.is_cyclic = True def _init_dihedral(self, n: int): # D_n: 2n elements. Encode rotations 0..n-1 as 0..n-1, reflections n..2n-1. order = 2 * n T = np.zeros((order, order), dtype=int) for i in range(order): for j in range(order): T[i, j] = self._dihedral_mul(i, j, n) self.n = order self.G_table = T self.F, self.irrep_dims = self._build_F_dihedral(n) self.F = self.F.to(self.device) self.Finv = torch.linalg.inv(self.F) self.is_cyclic = False def _dihedral_mul(self, i: int, j: int, n: int) -> int: ri, si = i % n, i // n rj, sj = j % n, j // n if si == 0 and sj == 0: return (ri + rj) % n if si == 0 and sj == 1: return n + (ri + rj) % n if si == 1 and sj == 0: return n + (ri - rj) % n return (ri - rj) % n def _build_F_dihedral(self, n: int): # Irreps of D_n: two 1-d (trivial, sign), and (n//2 - 1 + (n%2==0)) 2-d. order = 2 * n irrep_dims: List[int] = [1, 1] # 2-d irreps for k = 1, .., floor((n-1)/2) n_2d = (n - 1) // 2 irrep_dims += [2] * n_2d # If n even, two more 1-d irreps if n % 2 == 0: irrep_dims += [1, 1] # Build F by row-vectorization of irrep matrices at each group element rows = [] for g in range(order): r, s = g % n, g // n row = [] # trivial row.append(1.0 + 0.0j) # sign: +1 on rotations, -1 on reflections row.append(1.0 if s == 0 else -1.0) for k in range(1, n_2d + 1): ang = 2 * np.pi * k * r / n if s == 0: M = np.array([[np.cos(ang), -np.sin(ang)], [np.sin(ang), np.cos(ang)]], dtype=complex) else: M = np.array([[np.cos(ang), np.sin(ang)], [np.sin(ang), -np.cos(ang)]], dtype=complex) row.extend(M.flatten().tolist()) if n % 2 == 0: row.append((1.0 if s == 0 else -1.0) * (1.0 if r % 2 == 0 else -1.0)) row.append(1.0 if r % 2 == 0 else -1.0) rows.append(row) F = np.array(rows, dtype=complex) / np.sqrt(order) return torch.tensor(F, dtype=self.dtype), irrep_dims def _init_klein4(self): self.n = 4 T = np.array([[0, 1, 2, 3], [1, 0, 3, 2], [2, 3, 0, 1], [3, 2, 1, 0]]) self.G_table = T # Klein 4 ≅ Z_2 × Z_2; F = H ⊗ H / 2 H = np.array([[1, 1], [1, -1]], dtype=complex) F = np.kron(H, H) / 2.0 self.F = torch.tensor(F, dtype=self.dtype, device=self.device) self.Finv = torch.linalg.inv(self.F) self.irrep_dims = [1, 1, 1, 1] self.is_cyclic = False def _init_octahedral(self): from .octahedral import octahedral_group, octahedral_irreps T, _rotations = octahedral_group() self.n = 24 self.G_table = T F, irrep_dims = octahedral_irreps() self.F = torch.tensor(F, dtype=self.dtype, device=self.device) self.Finv = torch.linalg.inv(self.F) self.irrep_dims = irrep_dims self.is_cyclic = False def _init_product(self, factors: List["GroupAlgebra"]): if not factors: raise ValueError("product group needs at least one factor") n = int(np.prod([f.n for f in factors])) self.n = n # Build multiplication table by encoding multi-indices sizes = [f.n for f in factors] idx = np.array(np.meshgrid(*[np.arange(s) for s in sizes], indexing="ij")) flat = idx.reshape(len(sizes), -1).T # n × d T = np.zeros((n, n), dtype=int) for i in range(n): for j in range(n): out = tuple( f.G_table[flat[i, k], flat[j, k]] for k, f in enumerate(factors) ) T[i, j] = np.ravel_multi_index(out, sizes) self.G_table = T # F = F_1 ⊗ F_2 ⊗ ... (Theorem 2 from main text) F = factors[0].F for f in factors[1:]: F = torch.kron(F, f.F) self.F = F.to(self.device) self.Finv = torch.linalg.inv(self.F) # irrep dims of product: multiset of products irrep_dims = factors[0].irrep_dims for f in factors[1:]: irrep_dims = [a * b for a in irrep_dims for b in f.irrep_dims] self.irrep_dims = irrep_dims self.is_cyclic = all(f.is_cyclic for f in factors) # ------------------------------------------------------------------ # Common # ------------------------------------------------------------------ def _build_conv_tensor(self): ng = self.n T = np.zeros((ng, ng, ng), dtype=np.float32) for a in range(ng): for b in range(ng): T[a, b, self.G_table[a, b]] = 1.0 self.conv_tensor = torch.tensor(T, device=self.device) def _find_identity(self): for e in range(self.n): if np.all(self.G_table[e, :] == np.arange(self.n)) and \ np.all(self.G_table[:, e] == np.arange(self.n)): self.identity_idx = e return raise RuntimeError("no identity element found in multiplication table") def _build_inverse_table(self): e = self.identity_idx inv = np.zeros(self.n, dtype=int) for a in range(self.n): for b in range(self.n): if self.G_table[a, b] == e: inv[a] = b break self.inv_table = inv def to(self, device: torch.device) -> "GroupAlgebra": self.device = device self.F = self.F.to(device) self.Finv = self.Finv.to(device) self.conv_tensor = self.conv_tensor.to(device) return self def __repr__(self): return ( f"GroupAlgebra({self.group_type}, n={self.n}, " f"abelian={self.is_abelian}, irrep_dims={self.irrep_dims})" )