# tensor-group-sym/python/large_scale/starg_torch/features.py
"""Torch port of the ★_G invariant feature extractor (extractStarGFeatures).

Vectorized implementation of the 7 invariant feature blocks (DC, AC, per-
frequency power, per-row generalized Fourier power, ★_G-SVD tube norms,
direct invariants, spectral statistics) used in every linear and Neural-★_G
experiment in the paper. Designed to run on a single GPU for batches up to
the full QM9 dataset (134k molecules).
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, TYPE_CHECKING

import numpy as np
import torch

if TYPE_CHECKING:
    from .algebra import GroupAlgebra


@dataclass
class FeatureNormParams:
    """Parameters fitted by a training call of ``extract_starg_features``.

    A call with ``norm=None`` builds this object from its input batch; passing
    it back as ``norm`` reuses the same reshape sizes, row selections and
    standardization statistics so features of new data line up column-for-
    column with the training features. Field order matters: the extractor
    constructs this object positionally.
    """

    # Reshape sizes for the ★_G-SVD tube-norm block: the feature axis is
    # viewed as a (p, q) matrix (chosen by _best_reshape during training).
    p: int
    q: int
    # Number of singular tubes kept; the extractor sets it to min(p, q).
    n_svd: int
    # Boolean mask over the n_f feature rows marking rows treated as invariant
    # (near-zero variance across the group axis in the first training sample).
    inv_mask: torch.Tensor
    # Indices of the rows NOT marked in inv_mask (the equivariant rows).
    eq_idx: torch.Tensor
    # How many leading entries of eq_idx contribute per-row Fourier power.
    K_rows: int
    # Boolean mask over raw feature columns whose training std was >= 1e-8.
    keep: torch.Tensor
    # Mean and std of the kept training feature columns, used to z-score;
    # stds below 1e-10 were replaced by 1 to avoid dividing by ~0.
    mu: torch.Tensor
    sig: torch.Tensor


def _best_reshape(n_f: int) -> tuple[int, int]:
    best_min, p, q = 0, n_f, 1
    for pp in range(2, int(np.sqrt(n_f * 2)) + 1):
        qq = n_f // pp
        if qq < 1:
            continue
        if min(pp, qq) > best_min:
            best_min, p, q = min(pp, qq), pp, qq
    return p, q


def extract_starg_features(
    X: torch.Tensor,
    G: "GroupAlgebra",
    norm: Optional[FeatureNormParams] = None,
    K_rows_default: int = 14,
) -> tuple[torch.Tensor, FeatureNormParams]:
    """Extract standardized invariant features from a batch of tensors.

    Seven feature blocks are concatenated column-wise:
    (a) DC component, (b) AC energy, (c) per-frequency power, (d) per-row
    generalized Fourier power, (e) ★_G-SVD tube norms, (f) direct invariants
    and (g) spectral statistics. NaN/inf entries are zeroed. Near-constant
    columns are then dropped, the rest are z-scored, and an intercept column
    of ones is prepended.

    Args:
        X: Batch of shape ``(N, n_f, |G|)``; the last axis is the group axis.
        G: Group algebra. This function reads ``G.n`` (checked against the
            group axis) and ``G.F`` (applied along the group axis as
            ``X @ G.F``).
        norm: Parameters from a previous training call. ``None`` means this
            is a training call: reshape sizes, the invariant-row mask,
            equivariant row indices and standardization statistics are fitted
            on ``X``. Otherwise the stored parameters are reused. The stored
            tensors are moved to ``X.device`` for use, without mutating
            ``norm``.
        K_rows_default: Upper bound on the number of equivariant rows whose
            per-row Fourier power is included (used on training calls only).

    Returns:
        ``(features, norm_params)``. ``features`` has shape ``(N, 1 + n_kept)``
        with a leading column of ones. ``norm_params`` is the freshly fitted
        object on a training call, or the passed-in ``norm`` otherwise.

    Raises:
        ValueError: if the group axis of ``X`` does not have length ``G.n``.
    """
    N, n_f, n_g = X.shape
    if n_g != G.n:
        raise ValueError(f"X group axis {n_g} != |G|={G.n}")
    device = X.device
    is_training = norm is None

    if is_training:
        p, q = _best_reshape(n_f)
        n_svd = min(p, q)
        # Identify rows constant under the group action (sample 0 as a proxy):
        # a row counts as invariant when its variance across the group axis is
        # negligible relative to the largest row variance.
        row_var = X[0].var(dim=1)  # (n_f,)
        inv_mask = row_var < 1e-8 * (row_var.max() + 1e-20)
        eq_idx = (~inv_mask).nonzero(as_tuple=False).flatten()
        K_rows = min(K_rows_default, eq_idx.numel())
    else:
        p, q, n_svd = norm.p, norm.q, norm.n_svd
        K_rows = norm.K_rows
        # Fitted tensors may live on a different device than X (e.g. fitted
        # on CPU, applied on GPU); move them next to X before indexing.
        inv_mask = norm.inv_mask.to(device)
        eq_idx = norm.eq_idx.to(device)

    # (a) DC component: mean over the group axis
    dc = X.mean(dim=2)  # (N, n_f)
    # (b) AC energy (std across group)
    ac = X.std(dim=2)  # (N, n_f)
    # (c) + (d) Generalized Fourier transform along the group axis (X @ G.F).
    # NOTE(review): real inputs are transformed in complex64 even when X is
    # float64, so float64 precision is not carried through here — confirm.
    F = G.F.to(X.dtype if X.is_complex() else torch.complex64).to(device)
    Xh = torch.einsum("nfk,kj->nfj", X.to(F.dtype), F)  # (N, n_f, n_g)
    # (c) per-frequency power summed over all rows
    col_power = (Xh.abs() ** 2).sum(dim=1).real.to(X.dtype)  # (N, n_g)
    # (d) per-row generalized Fourier power for first K equivariant rows
    row_power = (Xh[:, eq_idx[:K_rows], :].abs() ** 2).real.to(X.dtype)
    row_power = row_power.reshape(N, K_rows * n_g)
    # (e) ★_G-SVD tube norms
    tube_norms = _starg_svd_tube_norms(X, G, p, q, n_svd)
    # (f) direct invariants: group-index-0 value of each invariant row
    if inv_mask.any():
        inv_feat = X[:, inv_mask, 0]
    else:
        inv_feat = torch.zeros(N, 0, device=device, dtype=X.dtype)
    # (g) spectral statistics of each sample's (n_f, n_g) matrix
    stats = _spectral_stats(X)  # (N, 4)

    feat = torch.cat([dc, ac, col_power, row_power, tube_norms, inv_feat, stats], dim=1)
    # Replace NaN/±inf (e.g. std over a length-1 group axis) with 0 so they
    # cannot poison the standardization statistics.
    feat = torch.nan_to_num(feat, nan=0.0, posinf=0.0, neginf=0.0)

    if is_training:
        # Drop (near-)constant columns, then z-score the remainder; tiny stds
        # are replaced by 1 to avoid dividing by ~0.
        std = feat.std(dim=0)
        keep = std >= 1e-8
        mu = feat[:, keep].mean(dim=0)
        sig = feat[:, keep].std(dim=0)
        sig = torch.where(sig < 1e-10, torch.ones_like(sig), sig)
        norm = FeatureNormParams(p, q, n_svd, inv_mask, eq_idx, K_rows, keep, mu, sig)
    else:
        # Same device concern as above for the standardization statistics.
        keep = norm.keep.to(device)
        mu = norm.mu.to(device)
        sig = norm.sig.to(device)

    feat = (feat[:, keep] - mu) / sig
    intercept = torch.ones(N, 1, device=device, dtype=feat.dtype)
    return torch.cat([intercept, feat], dim=1), norm


def _starg_svd_tube_norms(X: torch.Tensor, G, p: int, q: int, n_svd: int) -> torch.Tensor:
    """Compute Frobenius norms of singular tubes for the (p, q, |G|) reshape."""
    N, n_f, n_g = X.shape
    pad = max(0, p * q - n_f)
    if pad > 0:
        X_padded = torch.cat([X, torch.zeros(N, pad, n_g, device=X.device, dtype=X.dtype)], dim=1)
    else:
        X_padded = X[:, : p * q, :]
    Xt = X_padded.reshape(N, p, q, n_g)
    F = G.F.to(torch.complex64).to(X.device)
    Xh = torch.einsum("npqk,kj->npqj", Xt.to(F.dtype), F)
    # Per-frequency SVD on (p, q) matrices
    Xh_perm = Xh.permute(0, 3, 1, 2)  # (N, n_g, p, q)
    s = torch.linalg.svdvals(Xh_perm)  # (N, n_g, min(p, q))
    s = s[..., :n_svd]
    # Tube norm = ||s(k, k, :)||_F  in the *original* domain.
    # By Parseval, ||s_orig||_F = ||s_hat||_F up to the 1/√|G| factor,
    # so we use the per-singular-tube Fourier norm directly.
    tube_norms = s.norm(dim=1) / np.sqrt(n_g)  # (N, n_svd)
    tube_norms, _ = tube_norms.sort(dim=1, descending=True)
    return tube_norms.real.to(X.dtype)


def _spectral_stats(X: torch.Tensor) -> torch.Tensor:
    """Nuclear / spectral norm, condition number, and entropy of singular values."""
    N, n_f, n_g = X.shape
    Xm = X.reshape(N, n_f, n_g).to(torch.float32)
    s = torch.linalg.svdvals(Xm)  # (N, min)
    nuclear = s.sum(dim=1)
    spectral = s[:, 0]
    s_clip = torch.clamp(s, min=1e-10)
    cond = s[:, 0] / s_clip[:, -1]
    p = s_clip / (s_clip.sum(dim=1, keepdim=True) + 1e-20)
    entropy = -(p * (p + 1e-20).log()).sum(dim=1)
    return torch.stack([nuclear, spectral, cond, entropy], dim=1).to(X.dtype)