# tensor-group-sym/python/large_scale/starg_torch/algebra.py
"""GroupAlgebra: PyTorch port of the ★_G algebra core.

Caches the multiplication table, generalized Fourier matrix F_G, irrep
dimensions, inverse table, and 0/1 convolution (structure) tensor. Supports
cyclic, dihedral, octahedral, Klein-4, and direct product groups. The MATLAB
reference implementation is core/StarGAlgebra.m; this version is a faithful
torch port designed for batched GPU execution.
"""

from __future__ import annotations

from typing import List, Optional, Tuple

import numpy as np
import torch


class GroupAlgebra:
    """Finite group algebra with cached tables and a generalized Fourier matrix.

    Attributes set by the constructor:
        group_type: lower-cased group family name.
        n: group order (number of elements).
        G_table: ``n x n`` integer numpy array; ``G_table[a, b]`` is the index
            of the product ``a * b``.
        F, Finv: ``n x n`` torch tensors holding the generalized Fourier matrix
            and its matrix inverse (on ``device``).
        irrep_dims: dimensions of the irreducible representations (see the
            review note in ``_init_product`` about column order for products).
        conv_tensor: ``n x n x n`` float32 tensor with
            ``conv_tensor[a, b, G_table[a, b]] == 1`` and zeros elsewhere.
        identity_idx: index of the identity element.
        inv_table: ``inv_table[a]`` is the index of the inverse of ``a``.
        is_abelian: True iff ``G_table`` is symmetric.
        is_cyclic: flag set by each ``_init_*`` constructor.
    """

    def __init__(
        self,
        group_type: str,
        param=None,
        *,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.complex64,
    ):
        """Build the algebra for one group family.

        Args:
            group_type: ``"cyclic"``, ``"dihedral"``, ``"klein4"``,
                ``"octahedral"`` or ``"product"`` (case-insensitive).
            param: ``n`` for ``"cyclic"`` (order n) and ``"dihedral"``
                (order 2n); a list of ``GroupAlgebra`` factors for
                ``"product"``; unused otherwise.
            device: device for ``F``, ``Finv`` and ``conv_tensor``
                (default: CPU).
            dtype: dtype used when building ``F``.

        Raises:
            ValueError: unknown ``group_type`` or an empty product.
            RuntimeError: the multiplication table has no identity element or
                some element has no inverse.
        """
        self.group_type = group_type.lower()
        self.device = device or torch.device("cpu")
        self.dtype = dtype

        if self.group_type == "cyclic":
            self._init_cyclic(int(param))
        elif self.group_type == "dihedral":
            self._init_dihedral(int(param))
        elif self.group_type == "klein4":
            self._init_klein4()
        elif self.group_type == "octahedral":
            self._init_octahedral()
        elif self.group_type == "product":
            self._init_product(param)
        else:
            raise ValueError(f"Unknown group type: {group_type}")

        self._build_conv_tensor()
        self._find_identity()
        self._build_inverse_table()
        self.is_abelian = bool(np.array_equal(self.G_table, self.G_table.T))

    # ------------------------------------------------------------------
    # Group constructors
    # ------------------------------------------------------------------

    def _init_cyclic(self, n: int):
        """Cyclic group Z_n: element k is k, product is addition mod n."""
        self.n = n
        self.G_table = (np.arange(n)[:, None] + np.arange(n)[None, :]) % n
        # F_G = DFT_n scaled by 1/sqrt(n) (unitary); abelian, irrep dims all 1
        F = np.fft.fft(np.eye(n)) / np.sqrt(n)
        self.F = torch.tensor(F, dtype=self.dtype, device=self.device)
        self.Finv = torch.linalg.inv(self.F)
        self.irrep_dims = [1] * n
        self.is_cyclic = True

    def _init_dihedral(self, n: int):
        """Dihedral group D_n of order 2n.

        Element ``k`` (0 <= k < n) is the rotation r^k and element ``n + k``
        is r^k s. Note that ``self.n`` is set to the group order 2n, not n.
        """
        order = 2 * n
        T = np.zeros((order, order), dtype=int)
        for i in range(order):
            for j in range(order):
                T[i, j] = self._dihedral_mul(i, j, n)
        self.n = order
        self.G_table = T
        self.F, self.irrep_dims = self._build_F_dihedral(n)
        self.F = self.F.to(self.device)
        self.Finv = torch.linalg.inv(self.F)
        self.is_cyclic = False

    def _dihedral_mul(self, i: int, j: int, n: int) -> int:
        """Index of the product i * j in D_n, using s r s = r^-1."""
        ri, si = i % n, i // n
        rj, sj = j % n, j // n
        if si == 0 and sj == 0:
            return (ri + rj) % n
        if si == 0 and sj == 1:
            return n + (ri + rj) % n
        if si == 1 and sj == 0:
            return n + (ri - rj) % n
        return (ri - rj) % n

    def _build_F_dihedral(self, n: int):
        """Return ``(F, irrep_dims)`` for D_n, with F on the default device.

        Column order: trivial, sign, the row-vectorized 2x2 irreps for
        k = 1 .. (n-1)//2, then (n even only) the two extra 1-d irreps.
        Every entry is scaled by the same 1/sqrt(2n).
        """
        # Irreps of D_n: two 1-d (trivial, sign), (n-1)//2 two-dimensional,
        # plus two more 1-d irreps when n is even.
        order = 2 * n
        irrep_dims: List[int] = [1, 1]
        # 2-d irreps for k = 1, .., floor((n-1)/2)
        n_2d = (n - 1) // 2
        irrep_dims += [2] * n_2d
        # If n even, two more 1-d irreps
        if n % 2 == 0:
            irrep_dims += [1, 1]
        # Build F by row-vectorization of irrep matrices at each group element
        rows = []
        for g in range(order):
            r, s = g % n, g // n
            row = []
            # trivial
            row.append(1.0 + 0.0j)
            # sign: +1 on rotations, -1 on reflections
            row.append(1.0 if s == 0 else -1.0)
            for k in range(1, n_2d + 1):
                ang = 2 * np.pi * k * r / n
                if s == 0:
                    M = np.array([[np.cos(ang), -np.sin(ang)],
                                  [np.sin(ang), np.cos(ang)]], dtype=complex)
                else:
                    M = np.array([[np.cos(ang), np.sin(ang)],
                                  [np.sin(ang), -np.cos(ang)]], dtype=complex)
                row.extend(M.flatten().tolist())
            if n % 2 == 0:
                row.append((1.0 if s == 0 else -1.0) * (1.0 if r % 2 == 0 else -1.0))
                row.append(1.0 if r % 2 == 0 else -1.0)
            rows.append(row)
        F = np.array(rows, dtype=complex) / np.sqrt(order)
        return torch.tensor(F, dtype=self.dtype), irrep_dims

    def _init_klein4(self):
        """Klein four-group Z_2 x Z_2 with a hard-coded table."""
        self.n = 4
        T = np.array([[0, 1, 2, 3],
                      [1, 0, 3, 2],
                      [2, 3, 0, 1],
                      [3, 2, 1, 0]])
        self.G_table = T
        # Klein 4 ≅ Z_2 × Z_2; F = H ⊗ H / 2
        H = np.array([[1, 1], [1, -1]], dtype=complex)
        F = np.kron(H, H) / 2.0
        self.F = torch.tensor(F, dtype=self.dtype, device=self.device)
        self.Finv = torch.linalg.inv(self.F)
        self.irrep_dims = [1, 1, 1, 1]
        self.is_cyclic = False

    def _init_octahedral(self):
        """Octahedral group (order 24) from the sibling ``octahedral`` module.

        Assumes ``octahedral_group()`` returns a 24x24 table first and
        ``octahedral_irreps()`` returns ``(F, irrep_dims)`` — TODO confirm.
        """
        from .octahedral import octahedral_group, octahedral_irreps
        T, _rotations = octahedral_group()
        self.n = 24
        self.G_table = T
        F, irrep_dims = octahedral_irreps()
        self.F = torch.tensor(F, dtype=self.dtype, device=self.device)
        self.Finv = torch.linalg.inv(self.F)
        self.irrep_dims = irrep_dims
        self.is_cyclic = False

    def _init_product(self, factors: List["GroupAlgebra"]):
        """Direct product of ``factors``.

        The element with factor indices ``(i_1, ..., i_d)`` has flat index
        ``np.ravel_multi_index((i_1, ..., i_d), [f.n for f in factors])``
        (C order: the last factor varies fastest).

        Raises:
            ValueError: if ``factors`` is empty.
        """
        if not factors:
            raise ValueError("product group needs at least one factor")
        sizes = [f.n for f in factors]
        n = int(np.prod(sizes))
        self.n = n
        # Per-factor coordinates of every product element, shape (n, d).
        coords = np.stack(np.unravel_index(np.arange(n), sizes), axis=1)
        # Multiply componentwise in each factor for all n*n pairs at once,
        # then re-encode the multi-index as a flat element index.
        prod_coords = tuple(
            np.asarray(f.G_table)[np.ix_(coords[:, k], coords[:, k])]
            for k, f in enumerate(factors)
        )
        self.G_table = np.ravel_multi_index(prod_coords, sizes).astype(int)
        # F = F_1 ⊗ F_2 ⊗ ... (Theorem 2 from main text). Bring each factor to
        # this algebra's device/dtype first: kron would otherwise fail on mixed
        # devices, and self.F would inherit the factors' dtype instead of
        # honoring the requested one.
        F = factors[0].F.to(device=self.device, dtype=self.dtype)
        for f in factors[1:]:
            F = torch.kron(F, f.F.to(device=self.device, dtype=self.dtype))
        self.F = F
        self.Finv = torch.linalg.inv(self.F)
        # irrep dims of product: multiset of products (copy so the first
        # factor's list is never shared with this instance).
        irrep_dims = list(factors[0].irrep_dims)
        for f in factors[1:]:
            irrep_dims = [a * b for a in irrep_dims for b in f.irrep_dims]
        self.irrep_dims = irrep_dims
        # NOTE(review): kron of row-vectorized F's orders columns by
        # (column of F_1, column of F_2), not by (irrep of F_1, irrep of F_2).
        # When a factor other than the last has an irrep of dimension > 1, the
        # columns of one product irrep are interleaved with other irreps and do
        # not follow irrep_dims order — confirm downstream code does not assume
        # contiguous per-irrep column blocks for such products.
        # NOTE(review): product([C2, C2]) reports is_cyclic=True while
        # "klein4" (the same group) reports False — confirm intended meaning.
        self.is_cyclic = all(f.is_cyclic for f in factors)

    # ------------------------------------------------------------------
    # Common
    # ------------------------------------------------------------------

    def _build_conv_tensor(self):
        """Cache the 0/1 structure tensor of the group product.

        ``conv_tensor[a, b, c] == 1`` exactly when ``G_table[a, b] == c``.
        """
        ng = self.n
        T = np.zeros((ng, ng, ng), dtype=np.float32)
        idx = np.arange(ng)
        # Broadcast (n,1), (1,n) and (n,n) indices: one scatter of n*n ones.
        T[idx[:, None], idx[None, :], self.G_table] = 1.0
        self.conv_tensor = torch.tensor(T, device=self.device)

    def _find_identity(self):
        """Set ``identity_idx`` to the element whose row and column in
        ``G_table`` both equal ``0..n-1``.

        Raises:
            RuntimeError: if no such element exists.
        """
        expected = np.arange(self.n)
        for e in range(self.n):
            if np.array_equal(self.G_table[e, :], expected) and np.array_equal(
                self.G_table[:, e], expected
            ):
                self.identity_idx = e
                return
        raise RuntimeError("no identity element found in multiplication table")

    def _build_inverse_table(self):
        """Set ``inv_table[a]`` to the first ``b`` with ``G_table[a, b] == e``.

        Raises:
            RuntimeError: if some element has no right inverse, i.e. the table
                is not a group table (previously such entries silently got 0).
        """
        hits = np.asarray(self.G_table) == self.identity_idx
        missing = ~hits.any(axis=1)
        if missing.any():
            raise RuntimeError(
                f"no inverse found for element(s) {np.flatnonzero(missing).tolist()}"
            )
        # argmax returns the first True column in each row.
        self.inv_table = hits.argmax(axis=1).astype(int)

    def to(self, device: torch.device) -> "GroupAlgebra":
        """Move the cached tensors (F, Finv, conv_tensor) to ``device``.

        Mutates this instance and returns it for chaining; the numpy tables
        are left untouched.
        """
        self.device = device
        self.F = self.F.to(device)
        self.Finv = self.Finv.to(device)
        self.conv_tensor = self.conv_tensor.to(device)
        return self

    def __repr__(self):
        return (
            f"GroupAlgebra({self.group_type}, n={self.n}, "
            f"abelian={self.is_abelian}, irrep_dims={self.irrep_dims})"
        )