"""GroupAlgebra: PyTorch port of the ★_G algebra core.
Caches the multiplication table, generalized Fourier matrix F_G, irrep
dimensions, the inverse table, and a one-hot convolution tensor. Supports
cyclic, dihedral, octahedral, Klein-4, and direct product groups (quaternion
groups are not accepted by this port's constructor). The MATLAB reference
implementation is core/StarGAlgebra.m; this version is a faithful torch port
designed for batched GPU execution.
"""
from __future__ import annotations
from typing import List, Optional, Tuple
import numpy as np
import torch
class GroupAlgebra:
    """Finite-group algebra with cached multiplication and Fourier data.

    Attributes populated during construction:

    * ``n`` -- group order.
    * ``G_table`` -- ``n x n`` integer array; ``G_table[a, b]`` is the index
      of the product of elements ``a`` and ``b``.
    * ``F`` / ``Finv`` -- generalized Fourier matrix and its matrix inverse,
      stored on ``device`` with dtype ``dtype``.
    * ``irrep_dims`` -- dimensions of the irreducible representations.
    * ``conv_tensor`` -- ``n x n x n`` float32 tensor that is 1 at
      ``[a, b, G_table[a, b]]`` and 0 elsewhere.
    * ``identity_idx`` / ``inv_table`` -- identity element and, per element,
      the index of its inverse.
    * ``is_abelian`` / ``is_cyclic`` -- structural flags.
    """

    def __init__(
        self,
        group_type: str,
        param=None,
        *,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.complex64,
    ):
        """Build the algebra for one supported group.

        Args:
            group_type: ``"cyclic"``, ``"dihedral"``, ``"klein4"``,
                ``"octahedral"`` or ``"product"`` (case-insensitive).
            param: ``n`` for ``"cyclic"`` (Z_n) and ``"dihedral"`` (D_n, with
                2n elements); a list of ``GroupAlgebra`` factors for
                ``"product"``; unused by the other group types.
            device: Device for the cached tensors (CPU when omitted).
            dtype: dtype of ``F`` and ``Finv``.

        Raises:
            ValueError: Unknown ``group_type``, non-positive order, or an
                empty factor list.
        """
        self.group_type = group_type.lower()
        self.device = device or torch.device("cpu")
        self.dtype = dtype
        if self.group_type == "cyclic":
            self._init_cyclic(self._positive_order(param, "cyclic"))
        elif self.group_type == "dihedral":
            self._init_dihedral(self._positive_order(param, "dihedral"))
        elif self.group_type == "klein4":
            self._init_klein4()
        elif self.group_type == "octahedral":
            self._init_octahedral()
        elif self.group_type == "product":
            self._init_product(param)
        else:
            raise ValueError(f"Unknown group type: {group_type}")
        self._build_conv_tensor()
        self._find_identity()
        self._build_inverse_table()
        # A group is abelian iff its Cayley table is symmetric.
        self.is_abelian = bool(np.array_equal(self.G_table, self.G_table.T))

    # ------------------------------------------------------------------
    # Group constructors
    # ------------------------------------------------------------------
    @staticmethod
    def _positive_order(param, group_type: str) -> int:
        """Convert ``param`` to an int order parameter, rejecting values < 1.

        ``int(param)`` itself raises ``TypeError`` for ``None``.
        """
        n = int(param)
        if n < 1:
            raise ValueError(
                f"{group_type} group needs a positive order parameter, got {n}"
            )
        return n

    def _init_cyclic(self, n: int):
        """Set up Z_n: addition mod n, with the unitary DFT as F_G."""
        self.n = n
        self.G_table = (np.arange(n)[:, None] + np.arange(n)[None, :]) % n
        # F_G = DFT_n / sqrt(n); the group is abelian, so every irrep is 1-d.
        F = np.fft.fft(np.eye(n)) / np.sqrt(n)
        self.F = torch.tensor(F, dtype=self.dtype, device=self.device)
        self.Finv = torch.linalg.inv(self.F)
        self.irrep_dims = [1] * n
        self.is_cyclic = True

    def _init_dihedral(self, n: int):
        """Set up D_n (2n elements): rotations are 0..n-1, reflections n..2n-1."""
        order = 2 * n
        T = np.zeros((order, order), dtype=int)
        for i in range(order):
            for j in range(order):
                T[i, j] = self._dihedral_mul(i, j, n)
        self.n = order
        self.G_table = T
        F, self.irrep_dims = self._build_F_dihedral(n)
        self.F = F.to(self.device)
        self.Finv = torch.linalg.inv(self.F)
        self.is_cyclic = False

    def _dihedral_mul(self, i: int, j: int, n: int) -> int:
        """Return the index of element ``i`` times element ``j`` in D_n.

        Element ``k`` encodes ``(r, s) = (k % n, k // n)``:

        * ``s == 0`` is the rotation ``rot^r``.
        * ``s == 1`` is the reflection ``rot^r * refl``, so that
          ``refl * rot^r == rot^-r * refl``.
        """
        ri, si = i % n, i // n
        rj, sj = j % n, j // n
        if si == 0 and sj == 0:
            return (ri + rj) % n
        if si == 0 and sj == 1:
            return n + (ri + rj) % n
        if si == 1 and sj == 0:
            return n + (ri - rj) % n
        return (ri - rj) % n

    def _build_F_dihedral(self, n: int):
        """Return ``(F, irrep_dims)`` for D_n.

        Row ``g`` of ``F`` corresponds to group element ``g``. Its columns are,
        in order:

        * the trivial character and the sign character;
        * each 2-d irrep k = 1..floor((n-1)/2), as its 2x2 matrix flattened
          row-major;
        * for even n only, two more 1-d characters.

        ``irrep_dims`` lists the irreps in this same column order. Every
        entry is scaled by 1/sqrt(2n).
        """
        # Irreps of D_n: trivial and sign (1-d), floor((n-1)/2) 2-d irreps,
        # and, for even n, two further 1-d irreps.
        order = 2 * n
        irrep_dims: List[int] = [1, 1]
        n_2d = (n - 1) // 2
        irrep_dims += [2] * n_2d
        if n % 2 == 0:
            irrep_dims += [1, 1]
        rows = []
        for g in range(order):
            r, s = g % n, g // n
            row = []
            row.append(1.0 + 0.0j)  # trivial
            row.append(1.0 if s == 0 else -1.0)  # sign: -1 on reflections
            for k in range(1, n_2d + 1):
                ang = 2 * np.pi * k * r / n
                if s == 0:
                    # Rotation matrix R(ang).
                    M = np.array([[np.cos(ang), -np.sin(ang)],
                                  [np.sin(ang), np.cos(ang)]], dtype=complex)
                else:
                    # R(ang) @ diag(1, -1) for the reflection rot^r * refl.
                    M = np.array([[np.cos(ang), np.sin(ang)],
                                  [np.sin(ang), -np.cos(ang)]], dtype=complex)
                row.extend(M.flatten().tolist())
            if n % 2 == 0:
                row.append((1.0 if s == 0 else -1.0) * (1.0 if r % 2 == 0 else -1.0))
                row.append(1.0 if r % 2 == 0 else -1.0)
            rows.append(row)
        # NOTE(review): the 2-d blocks are not additionally scaled by sqrt(2),
        # so F is not unitary; Finv is obtained by explicit matrix inversion
        # rather than as F^H.
        F = np.array(rows, dtype=complex) / np.sqrt(order)
        return torch.tensor(F, dtype=self.dtype), irrep_dims

    def _init_klein4(self):
        """Set up the Klein four-group (XOR table on 0..3)."""
        self.n = 4
        T = np.array([[0, 1, 2, 3],
                      [1, 0, 3, 2],
                      [2, 3, 0, 1],
                      [3, 2, 1, 0]])
        self.G_table = T
        # Klein 4 ≅ Z_2 × Z_2; F = H ⊗ H / 2
        H = np.array([[1, 1], [1, -1]], dtype=complex)
        F = np.kron(H, H) / 2.0
        self.F = torch.tensor(F, dtype=self.dtype, device=self.device)
        self.Finv = torch.linalg.inv(self.F)
        self.irrep_dims = [1, 1, 1, 1]
        self.is_cyclic = False

    def _init_octahedral(self):
        """Set up the order-24 octahedral group from the sibling ``octahedral`` module."""
        from .octahedral import octahedral_group, octahedral_irreps
        T, _rotations = octahedral_group()
        self.n = 24
        self.G_table = T
        F, irrep_dims = octahedral_irreps()
        self.F = torch.tensor(F, dtype=self.dtype, device=self.device)
        self.Finv = torch.linalg.inv(self.F)
        self.irrep_dims = irrep_dims
        self.is_cyclic = False

    def _init_product(self, factors: List["GroupAlgebra"]):
        """Set up the direct product of already-built ``factors``.

        Element ``i`` is the C-order multi-index ``np.unravel_index(i, sizes)``
        of per-factor elements; elements multiply componentwise. ``F`` is the
        Kronecker product of the factors' ``F`` matrices. Like the other
        constructors, it is stored on this algebra's ``device`` and ``dtype``.

        Raises:
            ValueError: If ``factors`` is empty or None.
        """
        if not factors:
            raise ValueError("product group needs at least one factor")
        sizes = [f.n for f in factors]
        n = int(np.prod(sizes))
        self.n = n
        # coords[k][i] is factor k's element inside product element i.
        coords = np.unravel_index(np.arange(n), sizes)
        # parts[k][i, j] = factor-k product of coords[k][i] and coords[k][j].
        parts = tuple(
            np.asarray(f.G_table)[np.ix_(c, c)] for f, c in zip(factors, coords)
        )
        self.G_table = np.ravel_multi_index(parts, sizes).astype(int)
        # F = F_1 ⊗ F_2 ⊗ ... (Theorem 2 from main text). Each factor is cast
        # first, so the result honours self.dtype/self.device even when the
        # factors were built with a different dtype or on another device.
        F = factors[0].F.to(device=self.device, dtype=self.dtype)
        for f in factors[1:]:
            F = torch.kron(F, f.F.to(device=self.device, dtype=self.dtype))
        self.F = F
        self.Finv = torch.linalg.inv(self.F)
        # Irrep dims of the product: all pairwise products, with factor 0's
        # irreps varying slowest. Copy so we never alias a factor's list.
        # NOTE(review): when a factor is non-abelian, the columns of one
        # product irrep are not contiguous in kron(F_1, F_2), so irrep_dims
        # cannot be used to slice F into consecutive blocks -- confirm callers
        # don't rely on that.
        irrep_dims = list(factors[0].irrep_dims)
        for f in factors[1:]:
            irrep_dims = [a * b for a in irrep_dims for b in f.irrep_dims]
        self.irrep_dims = irrep_dims
        # NOTE(review): True whenever every factor is cyclic, even when the
        # product itself is not a cyclic group (e.g. Z_2 x Z_2) -- confirm
        # callers read it as "built from cyclic factors".
        self.is_cyclic = all(f.is_cyclic for f in factors)

    # ------------------------------------------------------------------
    # Common
    # ------------------------------------------------------------------
    def _build_conv_tensor(self):
        """Build ``conv_tensor``: 1.0 at ``[a, b, G_table[a, b]]``, else 0.0."""
        ng = self.n
        T = np.zeros((ng, ng, ng), dtype=np.float32)
        idx = np.arange(ng)
        T[idx[:, None], idx[None, :], self.G_table] = 1.0
        self.conv_tensor = torch.tensor(T, device=self.device)

    def _find_identity(self):
        """Set ``identity_idx`` to the first e with e*x == x*e == x for all x.

        Raises:
            RuntimeError: If the table has no such element.
        """
        expected = np.arange(self.n)
        for e in range(self.n):
            if np.array_equal(self.G_table[e, :], expected) and np.array_equal(
                self.G_table[:, e], expected
            ):
                self.identity_idx = e
                return
        raise RuntimeError("no identity element found in multiplication table")

    def _build_inverse_table(self):
        """Set ``inv_table[a]`` to the first b with ``G_table[a, b] == identity``.

        Raises:
            RuntimeError: If some element has no such b (malformed table).
        """
        hits = np.asarray(self.G_table) == self.identity_idx
        if not hits.any(axis=1).all():
            raise RuntimeError("some element has no inverse in multiplication table")
        # argmax over a boolean row returns the index of its first True.
        self.inv_table = np.argmax(hits, axis=1).astype(int)

    def to(self, device: torch.device) -> "GroupAlgebra":
        """Move ``F``, ``Finv`` and ``conv_tensor`` to ``device`` in place; return self."""
        self.device = device
        self.F = self.F.to(device)
        self.Finv = self.Finv.to(device)
        self.conv_tensor = self.conv_tensor.to(device)
        return self

    def __repr__(self):
        return (
            f"GroupAlgebra({self.group_type}, n={self.n}, "
            f"abelian={self.is_abelian}, irrep_dims={self.irrep_dims})"
        )