# tensor-group-sym/python/large_scale/starg_torch/test_svd.py
"""Reconstruction-property test for ``starg_svd``.

Verifies ``U ★_G S ★_G Vh ≈ A`` for both the abelian fast path
(cyclic Z_12 and Z_8) and the non-abelian per-irrep block path (chiral
octahedral group O), on rectangular and square inputs. Run as a script:

    python -m starg_torch.test_svd

Pass criterion: the max element-wise reconstruction error in float32
must be below a per-case threshold — 1e-4 for the cyclic cases and
1e-3 for the octahedral cases (the SVD is computed in complex64
internally).
"""

from __future__ import annotations

import sys

import torch

from .algebra import GroupAlgebra
from .product import starg_product
from .svd import starg_svd


def _check(name: str, group: GroupAlgebra, L: int, M: int, atol: float) -> bool:
    """Run one ★_G-SVD reconstruction check and print a PASS/FAIL line.

    Draws a random real float32 tensor ``A`` of shape ``(L, M, group.n)``,
    decomposes it with ``starg_svd`` and rebuilds it as
    ``(U ★_G S) ★_G Vh`` using two ``starg_product`` calls.

    Args:
        name: Label printed in the report line.
        group: Group algebra defining the ★_G product. ``group.n`` is used
            as the size of the last axis of ``A``.
        L: Size of the first axis of ``A``.
        M: Size of the second axis of ``A``.
        atol: Threshold on the max element-wise absolute reconstruction
            error.

    Returns:
        True iff the max element-wise absolute error is strictly below
        ``atol``. A NaN error compares False and therefore counts as a
        failure.
    """
    # Seed the *global* RNG (rather than a local generator) so that any RNG
    # use inside starg_svd, if it has any, is reproducible as well.
    torch.manual_seed(0)
    A = torch.randn(L, M, group.n, dtype=torch.float32)

    U, S, Vh = starg_svd(A, group)
    A_reco = starg_product(starg_product(U, S, group), Vh, group)

    residual = A - A_reco
    err = residual.abs().max().item()
    # Relative error in the Frobenius norm, ||A - A_reco||_F / ||A||_F.
    # Numerator and denominator use the same norm, so the ratio is a proper
    # relative error. The 1e-12 floor guards against a zero-norm input.
    rel = residual.norm().item() / max(A.norm().item(), 1e-12)

    passed = err < atol
    status = "PASS" if passed else "FAIL"
    print(
        f"[{status}] {name}: |G|={group.n} (L,M)=({L},{M}) "
        f"max_err={err:.3e}  rel={rel:.3e}  threshold={atol:.0e}"
    )
    return passed


def main():
    """Run every reconstruction case, print a summary and exit.

    Every case is executed even if an earlier one fails, so the report is
    complete. The process exits with status 0 if all cases passed and 1
    otherwise.
    """
    G_oct = GroupAlgebra("octahedral")

    # (label, group, L, M, atol). The octahedral cases use a looser threshold
    # (1e-3) than the cyclic ones (1e-4). This is presumably because the
    # per-irrep block path accumulates more rounding error — TODO confirm.
    cases = [
        # 1. Abelian cyclic Z_12, fast path (per-Fourier-bin SVD), L != M.
        ("cyclic Z_12 (abelian)", GroupAlgebra("cyclic", 12), 8, 10, 1e-4),
        # 2. Non-abelian chiral octahedral group, per-irrep block path, L != M.
        ("chiral octahedral O (non-abelian)", G_oct, 6, 9, 1e-3),
        # 3. Cyclic Z_8 with a square L = M input (full-rank reconstruction).
        ("cyclic Z_8 square 6x6", GroupAlgebra("cyclic", 8), 6, 6, 1e-4),
        # 4. Octahedral with a square L = M input (full-rank reconstruction).
        ("octahedral square 8x8", G_oct, 8, 8, 1e-3),
    ]

    # Materialise all results before reducing. A bare all(<generator>) would
    # short-circuit and skip the remaining cases after the first failure.
    results = [
        _check(label, group, L=rows, M=cols, atol=atol)
        for label, group, rows, cols, atol in cases
    ]
    ok = all(results)

    print()
    print("OVERALL:", "PASS" if ok else "FAIL")
    sys.exit(0 if ok else 1)


if __name__ == "__main__":
    main()