"""Batched ★_G-SVD on GPU.

★_G-SVD (manuscript SI): Generalized-Fourier-domain SVD -- compute a standard
matrix SVD per irrep block and transform back.

For abelian groups every block is 1×1, so this reduces to torch.linalg.svd
along the group axis after the generalized Fourier transform. For non-abelian
groups (e.g. the chiral octahedral group used in the Wigner--Eckart
experiment) we do per-irrep block SVD and reassemble the result into the
standard (U, S, Vh) Fourier-domain layout that ``starg_product`` consumes.

Derivation. Let A have shape (..., L, M, n) with n = |G|. Its generalized
Fourier transform Ah(:,:,k) re-indexes the n axis by ``F``: each k slice
along n is one entry of one irrep matrix, with ``irrep_dims`` and
row-vectorization fixing the layout. For irrep ρ of dimension d_ρ, the d_ρ²
slices at positions cursor:cursor+d_ρ² collectively form the
(L*d_ρ, M*d_ρ) "per-irrep block" matrix Â_ρ. SVD gives
Â_ρ = U_ρ diag(σ_ρ) V_ρ^H with K_ρ = min(L*d_ρ, M*d_ρ) = K * d_ρ singular
values per irrep, where K = min(L, M).

Reassembly. Reshape U_ρ ∈ (L*d_ρ, K*d_ρ) by splitting the row index as
(l, i) ∈ L × d_ρ and the column index as (k, p) ∈ K × d_ρ. Place the
(l, i, k, p)-cell at U_h[l, k, cursor + i*d_ρ + p]. The same shape contract
recovers the abelian path when d_ρ ≡ 1: U_h reduces to (..., L, K, n) with
no inner-block structure. Identical reshapes apply to Vh; S becomes
f-diagonal in the sense that ``S_h[k, k, cursor + i*d_ρ + i] = σ_{k, i}``,
all other entries zero.

``U ★_G S ★_G Vh = A`` is then exact (up to numerical SVD precision) for any
group, abelian or not, by construction: the per-irrep block-matmul rule of
``_block_diag_matmul`` in ``product.py`` is the inverse of this construction.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, Tuple

import torch

if TYPE_CHECKING:
    from .algebra import GroupAlgebra


def starg_svd(
    A: torch.Tensor,
    G: "GroupAlgebra",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """★_G-SVD of a batched tensor.

    Args:
        A: tensor of shape (..., L, M, n) where n = |G|.
        G: GroupAlgebra carrying ``n``, ``F``, ``Finv``, ``irrep_dims`` and
            ``is_abelian``.

    Returns:
        U : (..., L, K, n)
        S : (..., K, K, n), f-diagonal in the sense of the module docstring.
        Vh : (..., K, M, n)
        with K = min(L, M). The product ``U ★_G S ★_G Vh`` reconstructs A.

    Raises:
        ValueError: if A has fewer than three axes or its last axis does not
            match ``G.n``.
    """
    if A.dim() < 3:
        raise ValueError(
            f"expected A of shape (..., L, M, n), got {tuple(A.shape)}"
        )
    if A.shape[-1] != G.n:
        raise ValueError(f"group axis {A.shape[-1]} != |G|={G.n}")

    L, M = A.shape[-3], A.shape[-2]
    K = min(L, M)
    F, Finv = G.F, G.Finv

    # Forward generalized Fourier transform along the group axis, carried out
    # in the dtype of F.
    Ah = torch.einsum("...lmk,kn->...lmn", A.to(F.dtype), F)

    if G.is_abelian:
        U_h, S_h, Vh_h = _per_bin_svd(Ah, K)
    else:
        U_h, S_h, Vh_h = _per_irrep_svd(Ah, G.irrep_dims, L, M, K)

    # Inverse generalized Fourier transform along the n axis. Only the real
    # part is kept: this relies on A being real and on the per-irrep /
    # per-bin SVD preserving the conjugate symmetry of the Fourier-domain
    # blocks.
    U = _inverse_transform(U_h, Finv, A.dtype)
    S = _inverse_transform(S_h, Finv, A.dtype)
    Vh = _inverse_transform(Vh_h, Finv, A.dtype)
    return U, S, Vh


def _inverse_transform(
    X_h: torch.Tensor, Finv: torch.Tensor, dtype: torch.dtype
) -> torch.Tensor:
    """Apply ``Finv`` along the last axis, keep the real part, cast to dtype."""
    return torch.einsum("...lmk,kn->...lmn", X_h, Finv).real.to(dtype)


def _per_bin_svd(
    Ah: torch.Tensor, K: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Abelian fast path: SVD per Fourier bin.

    Args:
        Ah: Fourier-domain tensor of shape (..., L, M, n).
        K: min(L, M); kept for signature symmetry with ``_per_irrep_svd``
            (the reduced SVD determines it implicitly).

    Returns:
        Fourier-domain (U_h, S_h, Vh_h) of shapes
        (..., L, K, n), (..., K, K, n), (..., K, M, n).
    """
    # Move the n axis to the front of the (L, M, n) trailing trio so that
    # torch.linalg.svd treats each bin as a separate matrix.
    Ah_p = Ah.movedim(-1, -3).contiguous()  # (..., n, L, M)
    U_p, S_p, Vh_p = torch.linalg.svd(Ah_p, full_matrices=False)
    # U_p: (..., n, L, K)   S_p: (..., n, K) (real)   Vh_p: (..., n, K, M)

    U_h = U_p.movedim(-3, -1).contiguous()  # (..., L, K, n)
    Vh_h = Vh_p.movedim(-3, -1).contiguous()  # (..., K, M, n)

    # Build the f-diagonal S: at each n-bin, S[k, k] = σ_k, off-diagonal 0.
    S_diag = torch.diag_embed(S_p.to(U_p.dtype))  # (..., n, K, K)
    S_h = S_diag.movedim(-3, -1).contiguous()  # (..., K, K, n)
    return U_h, S_h, Vh_h


def _per_irrep_svd(
    Ah: torch.Tensor,
    irrep_dims: Sequence[int],
    L: int,
    M: int,
    K: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Per-irrep SVD + reassembly for non-abelian groups.

    Implements the construction in this module's docstring.

    Args:
        Ah: Fourier-domain tensor of shape (..., L, M, n).
        irrep_dims: dimension d of each irrep, in the order their d² entries
            are laid out along the n axis.
        L, M: the two matrix axes of Ah.
        K: min(L, M).

    Returns:
        Fourier-domain (U_h, S_h, Vh_h) of shapes
        (..., L, K, n), (..., K, K, n), (..., K, M, n) such that the
        per-irrep ★_G block-matmul rule recovers Ah exactly (up to SVD
        precision) at full rank.
    """
    n = sum(d * d for d in irrep_dims)
    batch = tuple(Ah.shape[:-3])
    dtype = Ah.dtype
    device = Ah.device

    U_h = torch.zeros((*batch, L, K, n), dtype=dtype, device=device)
    S_h = torch.zeros((*batch, K, K, n), dtype=dtype, device=device)
    Vh_h = torch.zeros((*batch, K, M, n), dtype=dtype, device=device)

    # δ_{k=k'} factor of the f-diagonal S; identical for every irrep, so it
    # is built once outside the loop.
    K_eye = torch.eye(K, device=device, dtype=dtype)

    cursor = 0
    for d in irrep_dims:
        block_size = d * d
        sl = slice(cursor, cursor + block_size)

        # Extract per-irrep block: (..., L, M, d²) -> (..., L, M, d, d),
        # then group it as an (L*d, M*d) matrix with rows (l, i) and
        # cols (m, j) by swapping the M and first-d axes.
        Ar = Ah[..., sl].reshape(*batch, L, M, d, d)
        Ar_mat = Ar.transpose(-3, -2).reshape(*batch, L * d, M * d)

        # Per-irrep SVD: K_ρ = K * d singular values.
        U_rho, S_rho, Vh_rho = torch.linalg.svd(Ar_mat, full_matrices=False)
        # U_rho (..., L*d, K*d), S_rho (..., K*d) real, Vh_rho (..., K*d, M*d)

        # ----- U reassembly -----
        # (L*d, K*d) -> (L, d, K, d) -> (L, K, d, d) -> flat (L, K, d²)
        U_blk = U_rho.reshape(*batch, L, d, K, d).transpose(-3, -2)
        U_h[..., sl] = U_blk.reshape(*batch, L, K, block_size)

        # ----- Vh reassembly -----
        # (K*d, M*d) -> (K, d, M, d) -> (K, M, d, d) -> flat (K, M, d²)
        Vh_blk = Vh_rho.reshape(*batch, K, d, M, d).transpose(-3, -2)
        Vh_h[..., sl] = Vh_blk.reshape(*batch, K, M, block_size)

        # ----- S reassembly -----
        # σ_ρ has shape (..., K*d). View as (..., K, d): σ[k, i] = σ_ρ[k*d + i].
        # The f-diagonal block is
        #   S_block[..., k, k', i, j] = δ_{k=k'} · σ[k, i] · δ_{i=j}.
        sigma = S_rho.reshape(*batch, K, d).to(dtype)  # (..., K, d)
        sigma_dd = torch.diag_embed(sigma)  # (..., K, d, d): δ_{i=j}
        S_block = K_eye[:, :, None, None] * sigma_dd[..., :, None, :, :]
        S_h[..., sl] = S_block.reshape(*batch, K, K, block_size)

        cursor += block_size

    return U_h, S_h, Vh_h


def _truncate_singular_tubes(S: torch.Tensor, k: int) -> torch.Tensor:
    """Zero every diagonal tube of S except the k with largest Frobenius norm.

    Args:
        S: f-diagonal tensor of shape (..., K, K, n).
        k: number of tubes to keep; values above K are clamped to K.

    Returns:
        Tensor of the same shape as S with all but the selected diagonal
        tubes (and every off-diagonal entry) set to zero.

    Raises:
        ValueError: if k is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    K = S.shape[-3]
    k = min(k, K)

    # torch.diagonal appends the diagonal axis LAST:
    # (..., K, K, n) -> (..., n, K). The tube norm therefore reduces over
    # the n axis, which now sits at position -2.
    tube_norms = S.diagonal(dim1=-3, dim2=-2).norm(dim=-2)  # (..., K)

    _, idx = torch.topk(tube_norms, k, dim=-1)
    mask = torch.zeros_like(tube_norms)
    mask.scatter_(-1, idx, 1.0)

    eye = torch.eye(K, device=S.device, dtype=mask.dtype)
    # (..., K, 1, 1) * (K, K, 1) -> (..., K, K, 1): mask[k] on the diagonal.
    keep = mask[..., :, None, None] * eye[..., None]
    return S * keep


def starg_truncate(A: torch.Tensor, G: "GroupAlgebra", k: int) -> torch.Tensor:
    """Rank-k truncation: zero out singular tubes beyond k, then reconstruct.

    Tubes are ranked by the Frobenius norm of each diagonal tube of S along
    the group axis; the k largest are kept.

    Args:
        A: tensor of shape (..., L, M, n) where n = |G|.
        G: GroupAlgebra used for the ★_G-SVD and ★_G-product.
        k: number of singular tubes to keep (clamped to min(L, M)).

    Returns:
        The ★_G-product ``U ★_G S_trunc ★_G Vh``.

    Raises:
        ValueError: if k is negative, or if A is rejected by ``starg_svd``.
    """
    U, S, Vh = starg_svd(A, G)
    S_trunc = _truncate_singular_tubes(S, k)

    # Imported at call time, as in the original implementation.
    from .product import starg_product

    return starg_product(starg_product(U, S_trunc, G), Vh, G)