# tensor-group-sym / python / large_scale / starg_torch / svd.py
"""Batched ★_G-SVD on GPU.

★_G-SVD (manuscript SI): Generalized-Fourier-domain SVD, compute a standard
matrix SVD per irrep block and transform back. For abelian groups every
block is 1×1, so this reduces to torch.linalg.svd along the group axis
after the generalized Fourier transform. For non-abelian groups (e.g. the
chiral octahedral group used in the Wigner–Eckart experiment) we do a
per-irrep block SVD and reassemble the result into the standard
(U, S, Vh) Fourier-domain layout that ``starg_product`` consumes.

Derivation. Let A have shape (..., L, M, n) with n = |G|. Its generalized
Fourier transform Ah(:,:,k) re-indexes the n axis by ``F``: each k slice
along n is one entry of one irrep matrix, with ``irrep_dims`` and
row-vectorization fixing the layout. For irrep ρ of dimension d_ρ, the d_ρ²
slices at positions cursor:cursor+d_ρ² collectively form the (L*d_ρ, M*d_ρ)
"per-irrep block" matrix Â_ρ. SVD gives Â_ρ = U_ρ diag(σ_ρ) V_ρ^H with
K_ρ = min(L*d_ρ, M*d_ρ) = K * d_ρ singular values per irrep, where
K = min(L, M).

Reassembly. Reshape U_ρ ∈ (L*d_ρ, K*d_ρ) by splitting the row index as
(l, i) ∈ L × d_ρ and the column index as (k, p) ∈ K × d_ρ. Place the
(l, i, k, p)-cell at U_h[l, k, cursor + i*d_ρ + p]. The same shape
contract recovers the abelian path when d_ρ ≡ 1: U_h reduces to
(..., L, K, n) with no inner-block structure. Identical reshapes apply
to Vh; S becomes f-diagonal in the sense that
``S_h[k, k, cursor + i*d_ρ + i] = σ_{k, i}``, all other entries zero, where
σ_{k, i} is the (k*d_ρ + i)-th (0-based) singular value of Â_ρ.
``U ★_G S ★_G Vh = A`` is then exact (up to numerical SVD precision)
for any group, abelian or not, by construction: the per-irrep
block-matmul rule of ``_block_diag_matmul`` in ``product.py`` is the
inverse of this reassembly.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Tuple

import torch

if TYPE_CHECKING:
    from .algebra import GroupAlgebra


def starg_svd(
    A: torch.Tensor,
    G: "GroupAlgebra",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the ★_G-SVD ``A = U ★_G S ★_G Vh`` of a batched tensor.

    The group axis is taken to the generalized Fourier domain with ``G.F``.
    It is factored bin-by-bin for abelian groups, or irrep block by irrep
    block for non-abelian groups, and then mapped back with ``G.Finv``.

    Args:
        A : tensor of shape (..., L, M, n) where n = |G|.
        G : GroupAlgebra supplying ``n``, ``F``, ``Finv``, ``is_abelian``
            and, for non-abelian groups, ``irrep_dims``.

    Returns:
        (U, S, Vh) with shapes (..., L, K, n), (..., K, K, n), (..., K, M, n)
        where K = min(L, M). Each factor is cast to ``A.dtype``, and S is
        f-diagonal.

    Raises:
        ValueError: if the trailing axis of A does not have length |G|.
    """
    if A.shape[-1] != G.n:
        raise ValueError(f"group axis {A.shape[-1]} != |G|={G.n}")

    rows, cols = A.shape[-3], A.shape[-2]
    rank = min(rows, cols)
    F, Finv = G.F, G.Finv
    along_group = "...lmk,kn->...lmn"  # contract the trailing group axis

    A_hat = torch.einsum(along_group, A.to(F.dtype), F)
    factors_hat = (
        _per_bin_svd(A_hat, rank)
        if G.is_abelian
        else _per_irrep_svd(A_hat, G.irrep_dims, rows, cols, rank)
    )

    # Map back to the group domain. Only the real part is kept and the
    # result is cast to A.dtype. This presumes A is real and that the
    # Fourier-domain factors stay conjugate-symmetric across paired bins.
    # NOTE(review): torch.linalg.svd fixes singular vectors only up to a
    # phase, so that symmetry is not guaranteed by the backend. Worth a
    # reconstruction check whenever F is complex.
    U, S, Vh = (
        torch.einsum(along_group, T_hat, Finv).real.to(A.dtype)
        for T_hat in factors_hat
    )
    return U, S, Vh


def _per_bin_svd(Ah: torch.Tensor, K: int):
    """Abelian fast path: SVD per Fourier bin.

    Returns Fourier-domain (U_h, S_h, Vh_h) of shapes
    (..., L, K, n), (..., K, K, n), (..., K, M, n).
    """
    # Move the n axis to the front of the (L, M, n) trailing trio so that
    # torch.linalg.svd treats each bin as a separate matrix.
    Ah_p = Ah.movedim(-1, -3).contiguous()  # (..., n, L, M)
    U_p, S_p, Vh_p = torch.linalg.svd(Ah_p, full_matrices=False)
    # U_p: (..., n, L, K)   S_p: (..., n, K) (real)   Vh_p: (..., n, K, M)
    U_h = U_p.movedim(-3, -1).contiguous()           # (..., L, K, n)
    Vh_h = Vh_p.movedim(-3, -1).contiguous()         # (..., K, M, n)
    # Build the f-diagonal S: at each n-bin, S[k, k] = σ_k, off-diagonal 0.
    S_diag = torch.diag_embed(S_p.to(U_p.dtype))     # (..., n, K, K)
    S_h = S_diag.movedim(-3, -1).contiguous()        # (..., K, K, n)
    return U_h, S_h, Vh_h


def _per_irrep_svd(Ah: torch.Tensor, irrep_dims, L: int, M: int, K: int):
    """Per-irrep SVD + reassembly for non-abelian groups.

    Implements the construction in this module's docstring. Returns
    Fourier-domain (U_h, S_h, Vh_h) of shapes (..., L, K, n), (..., K, K, n),
    (..., K, M, n) such that the per-irrep ★_G block-matmul rule recovers
    Ah exactly (up to SVD precision) at full rank.
    """
    n = sum(d * d for d in irrep_dims)
    batch = tuple(Ah.shape[:-3])
    dtype = Ah.dtype
    device = Ah.device

    U_h = torch.zeros((*batch, L, K, n), dtype=dtype, device=device)
    S_h = torch.zeros((*batch, K, K, n), dtype=dtype, device=device)
    Vh_h = torch.zeros((*batch, K, M, n), dtype=dtype, device=device)

    cursor = 0
    for d in irrep_dims:
        block_size = d * d
        sl = slice(cursor, cursor + block_size)

        # Extract per-irrep block: (..., L, M, d²) -> (..., L, M, d, d)
        Ar = Ah[..., sl].reshape(*batch, L, M, d, d)
        # Group as a (L*d, M*d) matrix: rows (l, i), cols (m, j).
        # Permute (..., L, M, d, d) -> (..., L, d, M, d), then reshape
        # last 4 dims into (L*d, M*d).
        Ar_perm = Ar.permute(*range(len(batch)), -4, -2, -3, -1).contiguous()
        Ar_mat = Ar_perm.reshape(*batch, L * d, M * d)

        # Per-irrep SVD: K_ρ = K * d singular values
        U_rho, S_rho, Vh_rho = torch.linalg.svd(Ar_mat, full_matrices=False)
        # Shapes: U_rho (..., L*d, K*d), S_rho (..., K*d) real, Vh_rho (..., K*d, M*d)

        # ----- U reassembly -----
        # (L*d, K*d) -> (L, d, K, d) -> permute -> (L, K, d, d) -> flat (L, K, d²)
        U_rho_r = U_rho.reshape(*batch, L, d, K, d)
        U_rho_p = U_rho_r.permute(
            *range(len(batch)), -4, -2, -3, -1
        ).contiguous()  # (..., L, K, d, d)
        U_rho_flat = U_rho_p.reshape(*batch, L, K, block_size)
        U_h[..., :, :, sl] = U_rho_flat

        # ----- Vh reassembly -----
        # (K*d, M*d) -> (K, d, M, d) -> permute -> (K, M, d, d) -> flat (K, M, d²)
        Vh_rho_r = Vh_rho.reshape(*batch, K, d, M, d)
        Vh_rho_p = Vh_rho_r.permute(
            *range(len(batch)), -4, -2, -3, -1
        ).contiguous()  # (..., K, M, d, d)
        Vh_rho_flat = Vh_rho_p.reshape(*batch, K, M, block_size)
        Vh_h[..., :, :, sl] = Vh_rho_flat

        # ----- S reassembly -----
        # σ_ρ has shape (..., K*d). View as (..., K, d): σ[k, i] = σ_ρ[k*d + i].
        # The f-diagonal S_block at (k, k', i, j) is σ[k, i] · δ_{k=k'} · δ_{i=j}.
        S_rho_kd = S_rho.reshape(*batch, K, d).to(dtype)  # (..., K, d)
        # diag_embed turns the trailing K-dim into (K, K), ditto for d:
        # first place σ on the (i, j) diagonal: (..., K, d) -> (..., K, d, d)
        S_diag_d = torch.diag_embed(S_rho_kd)            # (..., K, d, d)
        # then place the (k, k')-diagonal: insert δ_{k=k'} via diag_embed on
        # axis -3 (the K axis). Use the broadcast pattern σ[k, i, j] -> σ[k, k', i, j]
        # by creating a (K, K, d, d) tensor with σ[k, i, j] on the diagonal in
        # the (k, k') plane.
        K_eye = torch.eye(K, device=device, dtype=dtype)
        S_block = K_eye[..., :, :, None, None] * S_diag_d[..., :, None, :, :]
        # S_block has shape (..., K, K, d, d) with the right f-diagonal pattern:
        #   S_block[..., k, k', i, j] = δ_{k=k'} · σ[k, i] · δ_{i=j}
        # Sanity: contracting K_eye[k,k'] * S_diag_d[k, i, j] indeed gives this.
        S_block_flat = S_block.reshape(*batch, K, K, block_size)
        S_h[..., :, :, sl] = S_block_flat

        cursor += block_size

    return U_h, S_h, Vh_h


def starg_truncate(A: torch.Tensor, G: "GroupAlgebra", k: int) -> torch.Tensor:
    """Rank-k truncation: zero out singular tubes beyond k, then reconstruct.

    Args:
        A : tensor of shape (..., L, M, n) where n = |G|.
        G : GroupAlgebra passed through to ``starg_svd`` and ``starg_product``.
        k : number of singular tubes to keep, ranked by the Frobenius norm of
            each diagonal tube S[..., t, t, :]. Values >= min(L, M) keep
            every tube.

    Returns:
        ``U ★_G S_k ★_G Vh``, where S_k is S with every diagonal tube outside
        the k largest-norm ones set to zero.
    """
    U, S, Vh = starg_svd(A, G)
    n_tubes = S.shape[-3]  # K = min(L, M)
    # topk raises if k exceeds the axis length; asking for more tubes than
    # exist simply means "keep them all".
    k = min(k, n_tubes)

    # S is (..., K, K, n). diagonal() removes both K axes and appends the
    # diagonal index LAST, giving (..., n, K). Each tube lives along the group
    # axis (dim -2 here), so reduce over it to get one norm per tube: (..., K).
    tube_norms = S.diagonal(dim1=-3, dim2=-2).norm(dim=-2)
    _, idx = torch.topk(tube_norms, k, dim=-1)
    mask = torch.zeros_like(tube_norms)
    mask.scatter_(-1, idx, 1.0)

    # Per-batch diagonal keep-mask (..., K, K), broadcast along the n axis.
    eye = torch.eye(n_tubes, device=S.device, dtype=S.dtype)
    keep = mask.unsqueeze(-1) * eye
    S_trunc = S * keep.unsqueeze(-1)

    from .product import starg_product
    return starg_product(starg_product(U, S_trunc, G), Vh, G)