"""Batched ★_G-SVD on GPU.
★_G-SVD (manuscript SI) is an SVD in the generalized Fourier domain: compute a
standard matrix SVD per irrep block and transform back. For abelian groups every
block is 1×1, so this reduces to torch.linalg.svd along the group axis
after the generalized Fourier transform. For non-abelian groups (e.g. the
chiral octahedral group used in the Wigner–Eckart experiment) we do per-
irrep block SVD and reassemble the result into the standard
(U, S, Vh) Fourier-domain layout that ``starg_product`` consumes.
Derivation. Let A have shape (..., L, M, n) with n = |G|. Its generalized
Fourier transform Ah(:,:,k) re-indexes the n axis by ``F``: each k slice
along n is one entry of one irrep matrix, with ``irrep_dims`` and
row-vectorization fixing the layout. For irrep ρ of dimension d_ρ, the d_ρ²
slices at positions cursor:cursor+d_ρ² collectively form the (L*d_ρ, M*d_ρ)
"per-irrep block" matrix Â_ρ. SVD gives Â_ρ = U_ρ diag(σ_ρ) V_ρ^H with
K_ρ = min(L*d_ρ, M*d_ρ) = K * d_ρ singular values per irrep, where
K = min(L, M).
Reassembly. Reshape U_ρ ∈ (L*d_ρ, K*d_ρ) by splitting the row index as
(l, i) ∈ L × d_ρ and the column index as (k, p) ∈ K × d_ρ. Place the
(l, i, k, p)-cell at U_h[l, k, cursor + i*d_ρ + p]. The same shape
contract recovers the abelian path when d_ρ ≡ 1: U_h reduces to
(..., L, K, n) with no inner-block structure. Identical reshapes apply
to Vh; S becomes f-diagonal in the sense that
``S_h[k, k, cursor + i*d_ρ + i] = σ_{k, i}``, all other entries zero.
``U ★_G S ★_G Vh = A`` is then exact (up to numerical SVD precision)
for any group, abelian or not, by construction: the per-irrep
block-matmul rule of ``_block_diag_matmul`` in ``product.py`` is the
inverse of this construction.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Tuple
import torch
if TYPE_CHECKING:
from .algebra import GroupAlgebra
def starg_svd(
    A: torch.Tensor,
    G: "GroupAlgebra",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the ★_G-SVD of a batched tensor.

    The input is carried to the generalized Fourier domain with ``G.F``,
    factored there (per bin when ``G.is_abelian`` is true, per irrep block
    otherwise) and each factor is carried back with ``G.Finv``.

    Args:
        A  : tensor of shape (..., L, M, n) where n = |G|.
        G  : GroupAlgebra providing n, F, Finv, irrep_dims, is_abelian.

    Returns:
        U  : (..., L, K, n)
        S  : (..., K, K, n), f-diagonal in the sense of the module docstring.
        Vh : (..., K, M, n)
        with K = min(L, M); all three are cast back to ``A.dtype``.
        The product ``U ★_G S ★_G Vh`` is intended to reconstruct A.

    Raises:
        ValueError: if the last axis of A does not have length ``G.n``.
    """
    if A.shape[-1] != G.n:
        raise ValueError(f"group axis {A.shape[-1]} != |G|={G.n}")

    n_rows = A.shape[-3]
    n_cols = A.shape[-2]
    rank = min(n_rows, n_cols)
    fwd, inv = G.F, G.Finv

    # Forward generalized Fourier transform along the trailing group axis.
    spectrum = torch.einsum("...lmk,kn->...lmn", A.to(fwd.dtype), fwd)

    if G.is_abelian:
        factors_hat = _per_bin_svd(spectrum, rank)
    else:
        factors_hat = _per_irrep_svd(
            spectrum, G.irrep_dims, n_rows, n_cols, rank
        )

    def to_group_domain(x_hat: torch.Tensor) -> torch.Tensor:
        # Inverse transform along the group axis; only the real part is kept
        # and the result is cast back to the caller's dtype.
        # NOTE(review): dropping the imaginary part assumes A is real and
        # that the Fourier-domain SVD factors keep the matching conjugate
        # symmetry — confirm for the groups in use.
        return torch.einsum("...lmk,kn->...lmn", x_hat, inv).real.to(A.dtype)

    U, S, Vh = (to_group_domain(factor) for factor in factors_hat)
    return U, S, Vh
def _per_bin_svd(Ah: torch.Tensor, K: int):
    """Abelian fast path: one reduced SVD per Fourier bin.

    Args:
        Ah : Fourier-domain tensor of shape (..., L, M, n).
        K  : not read by this function; the reduced SVD fixes the inner
             dimension to min(L, M) on its own.

    Returns:
        Fourier-domain (U_h, S_h, Vh_h) of shapes
        (..., L, K, n), (..., K, K, n), (..., K, M, n), each contiguous.
    """

    def bins_first(t: torch.Tensor) -> torch.Tensor:
        # (..., a, b, n) -> (..., n, a, b), so every bin is its own matrix.
        return t.transpose(-1, -2).transpose(-2, -3).contiguous()

    def bins_last(t: torch.Tensor) -> torch.Tensor:
        # (..., n, a, b) -> (..., a, b, n), back to the group-axis-last layout.
        return t.transpose(-3, -2).transpose(-2, -1).contiguous()

    # left: (..., n, L, K)   sigma: (..., n, K) (real)   right_h: (..., n, K, M)
    left, sigma, right_h = torch.linalg.svd(
        bins_first(Ah), full_matrices=False
    )

    # f-diagonal S: within each bin σ_k sits at [k, k], zeros elsewhere.
    # The singular values are cast to the dtype of the unitary factors.
    sigma_mat = torch.diag_embed(sigma.to(left.dtype))  # (..., n, K, K)

    return bins_last(left), bins_last(sigma_mat), bins_last(right_h)
def _per_irrep_svd(Ah: torch.Tensor, irrep_dims, L: int, M: int, K: int):
    """Per-irrep SVD + reassembly for non-abelian groups.

    For each irrep of dimension d, the d² consecutive slices of the group
    axis are regrouped into an (L*d, M*d) matrix, factored with a reduced
    SVD, and the factors are scattered back into the same d² slices.

    Args:
        Ah         : Fourier-domain tensor of shape (..., L, M, n).
        irrep_dims : iterable of irrep dimensions d; slices are consumed in
                     this order and must satisfy sum(d²) == n.
        L, M       : row / column sizes of Ah.
        K          : inner dimension, expected to equal min(L, M) (the
                     reshapes of the SVD factors fail otherwise).

    Returns:
        Fourier-domain (U_h, S_h, Vh_h) of shapes (..., L, K, n),
        (..., K, K, n), (..., K, M, n). Within one irrep block, viewing the
        d² slices as (d, d),
        ``sum_{k,p,c,q} U[l,k,i,p] S[k,c,p,q] Vh[c,m,q,j]`` recovers the
        corresponding block of Ah up to SVD precision.

    Raises:
        ValueError: if the group axis of Ah does not have length sum(d²).
    """
    # Materialise once: the dims are needed both for the size check and for
    # the loop, and a one-shot iterable would be exhausted by the first use.
    dims = tuple(int(d) for d in irrep_dims)
    n = sum(d * d for d in dims)
    if Ah.shape[-1] != n:
        # Without this check a too-long axis is silently truncated and a
        # too-short one fails later with an opaque reshape error.
        raise ValueError(
            f"group axis {Ah.shape[-1]} != sum of squared irrep dims {n}"
        )

    batch = tuple(Ah.shape[:-3])
    nb = len(batch)
    dtype = Ah.dtype
    device = Ah.device

    # Permutation swapping the two middle axes of the trailing four:
    # (..., a, b, c, e) -> (..., a, c, b, e).
    swap_mid = (*range(nb), nb, nb + 2, nb + 1, nb + 3)

    U_h = torch.zeros((*batch, L, K, n), dtype=dtype, device=device)
    S_h = torch.zeros((*batch, K, K, n), dtype=dtype, device=device)
    Vh_h = torch.zeros((*batch, K, M, n), dtype=dtype, device=device)

    # δ_{k=k'} does not depend on the irrep, so build it once: (K, K, 1, 1).
    K_eye = torch.eye(K, device=device, dtype=dtype)[:, :, None, None]

    cursor = 0
    for d in dims:
        block_size = d * d
        sl = slice(cursor, cursor + block_size)

        # (..., L, M, d²) -> (..., L, M, d, d) -> (..., L, d, M, d)
        # -> (L*d, M*d) matrix with rows (l, i) and columns (m, j).
        Ar_mat = (
            Ah[..., sl]
            .reshape(*batch, L, M, d, d)
            .permute(swap_mid)
            .reshape(*batch, L * d, M * d)
        )

        # U_rho (..., L*d, K*d), S_rho (..., K*d) real, Vh_rho (..., K*d, M*d)
        U_rho, S_rho, Vh_rho = torch.linalg.svd(Ar_mat, full_matrices=False)

        # U: (L*d, K*d) -> (L, d, K, d) -> (L, K, d, d) -> (L, K, d²)
        U_h[..., sl] = (
            U_rho.reshape(*batch, L, d, K, d)
            .permute(swap_mid)
            .reshape(*batch, L, K, block_size)
        )

        # Vh: (K*d, M*d) -> (K, d, M, d) -> (K, M, d, d) -> (K, M, d²)
        Vh_h[..., sl] = (
            Vh_rho.reshape(*batch, K, d, M, d)
            .permute(swap_mid)
            .reshape(*batch, K, M, block_size)
        )

        # S: view σ_ρ as σ[k, i] = σ_ρ[k*d + i], put it on the (i, j)
        # diagonal, then on the (k, k') diagonal:
        #   S_block[k, k', i, j] = δ_{k=k'} · σ[k, i] · δ_{i=j}
        sigma = S_rho.reshape(*batch, K, d).to(dtype)      # (..., K, d)
        S_diag_d = torch.diag_embed(sigma)                 # (..., K, d, d)
        S_block = K_eye * S_diag_d.unsqueeze(-3)           # (..., K, K, d, d)
        S_h[..., sl] = S_block.reshape(*batch, K, K, block_size)

        cursor += block_size

    return U_h, S_h, Vh_h
def starg_truncate(A: torch.Tensor, G: "GroupAlgebra", k: int) -> torch.Tensor:
    """Rank-k truncation: zero out singular tubes beyond k, then reconstruct.

    Args:
        A : tensor of shape (..., L, M, n) where n = |G|.
        G : GroupAlgebra passed through to ``starg_svd`` / ``starg_product``.
        k : number of singular tubes to keep. Values larger than
            K = min(L, M) are clamped to K, i.e. nothing is truncated.

    Returns:
        ``U ★_G S_k ★_G Vh`` where S_k keeps only the k diagonal tubes of S
        with the largest norm (selected independently per batch element).
    """
    U, S, Vh = starg_svd(A, G)
    K = S.shape[-3]

    # torch appends the diagonal axis LAST, so the diagonal tubes come out
    # as (..., n, K). The per-tube norm must therefore reduce over the group
    # axis at dim=-2; reducing over dim=-1 would collapse K instead.
    tubes = S.diagonal(dim1=-3, dim2=-2)  # (..., n, K)
    tube_norms = tubes.norm(dim=-2)  # (..., K)

    # topk raises when asked for more entries than exist; keeping "all K"
    # is the natural meaning of k >= K.
    k = min(k, K)
    _, idx = torch.topk(tube_norms, k, dim=-1)
    mask = torch.zeros_like(tube_norms)
    mask.scatter_(-1, idx, 1.0)

    # Match S's dtype so the product below does not upcast S_trunc.
    eye = torch.eye(K, device=S.device, dtype=S.dtype)
    # (..., K, 1, 1) * (K, K, 1) -> (..., K, K, 1): keep[k, k'] = mask[k]·δ_{k=k'}
    keep = mask[..., :, None, None] * eye[..., None]
    S_trunc = S * keep

    from .product import starg_product

    return starg_product(starg_product(U, S_trunc, G), Vh, G)