"""Torch port of the ★_G invariant feature extractor (extractStarGFeatures).

Vectorized implementation of the 7 invariant feature blocks (DC, AC,
per-frequency power, per-row generalized Fourier power, ★_G-SVD tube norms,
direct invariants, spectral statistics) used in every linear and Neural-★_G
experiment in the paper.

Designed to run on a single GPU for batches up to the full QM9 dataset
(134k molecules).
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, TYPE_CHECKING

import numpy as np
import torch

if TYPE_CHECKING:
    from .algebra import GroupAlgebra


@dataclass
class FeatureNormParams:
    """Feature layout and normalization state fitted by ``extract_starg_features``.

    An instance is produced by a fitting call (``norm=None``) and can be passed
    back in to apply the same layout and standardization to another batch.
    """

    p: int  # first dimension of the (p, q, |G|) reshape used for the tube norms
    q: int  # second dimension of the (p, q, |G|) reshape
    n_svd: int  # number of singular tubes kept, min(p, q) at fit time
    inv_mask: torch.Tensor  # bool (n_f,): rows treated as constant along the group axis
    eq_idx: torch.Tensor  # indices of the rows where inv_mask is False
    K_rows: int  # how many leading eq_idx rows get per-row Fourier power features
    keep: torch.Tensor  # bool mask over raw feature columns retained after fitting
    mu: torch.Tensor  # per-column mean of the retained features
    sig: torch.Tensor  # per-column std of the retained features (never below 1e-10)


def _best_reshape(n_f: int) -> tuple[int, int]:
    """Choose the ``(p, q)`` matrix shape used to fold ``n_f`` rows.

    Candidates are ``(pp, n_f // pp)`` for ``pp`` from 2 up to
    ``int(sqrt(2 * n_f))``. The first candidate with the largest
    ``min(pp, qq)`` wins. ``p * q`` may be smaller than ``n_f``; the caller
    then drops the trailing rows. If there is no candidate the result is
    ``(n_f, 1)``.
    """
    best_min, p, q = 0, n_f, 1
    for pp in range(2, int(np.sqrt(n_f * 2)) + 1):
        qq = n_f // pp
        if qq < 1:
            continue
        squareness = min(pp, qq)
        # Strict comparison: on ties the earliest (smallest pp) candidate is kept.
        if squareness > best_min:
            best_min, p, q = squareness, pp, qq
    return p, q


def extract_starg_features(
    X: torch.Tensor,
    G: "GroupAlgebra",
    norm: Optional[FeatureNormParams] = None,
    K_rows_default: int = 14,
) -> tuple[torch.Tensor, FeatureNormParams]:
    """Extract invariant features from a batch of tensors.

    Args:
        X: Tensor of shape ``(N, n_f, |G|)``.
        G: Group algebra object; only ``G.n`` (group order) and ``G.F``
            (the ``|G| x |G|`` transform matrix) are read here.
        norm: ``None`` to fit the feature layout and normalization on ``X``,
            or parameters from an earlier fitting call to re-apply them.
        K_rows_default: Upper bound on the number of non-invariant rows that
            contribute per-row Fourier power features. Only used when fitting.

    Returns:
        ``(features, norm_params)``. ``features`` has shape
        ``(N, 1 + n_kept)``: a leading column of ones followed by the
        standardized retained features. ``norm_params`` is the fitted
        parameters when fitting, otherwise the ``norm`` object that was
        passed in.

    Raises:
        ValueError: If the last axis of ``X`` differs from ``G.n``, or if a
            fit is requested on an empty batch.
    """
    N, n_f, n_g = X.shape
    if n_g != G.n:
        raise ValueError(f"X group axis {n_g} != |G|={G.n}")
    device = X.device

    is_training = norm is None
    if is_training:
        if N == 0:
            # The invariant-row detection below inspects sample 0.
            raise ValueError("cannot fit feature normalization on an empty batch")
        p, q = _best_reshape(n_f)
        n_svd = min(p, q)
        # Identify rows constant under the group action (sample 0 as a proxy).
        row_var = X[0].var(dim=1)  # (n_f,)
        inv_mask = row_var < 1e-8 * (row_var.max() + 1e-20)
        eq_idx = (~inv_mask).nonzero(as_tuple=False).flatten()
        K_rows = min(K_rows_default, eq_idx.numel())
    else:
        p, q, n_svd, K_rows = norm.p, norm.q, norm.n_svd, norm.K_rows
        # Stored tensors may have been fitted on another device; align them
        # with X so the indexing and arithmetic below cannot mix devices.
        inv_mask = norm.inv_mask.to(device)
        eq_idx = norm.eq_idx.to(device)

    # (a) DC component: mean along the group axis.
    dc = X.mean(dim=2)  # (N, n_f)

    # (b) AC energy: std along the group axis.
    ac = X.std(dim=2)  # (N, n_f)

    # (c) + (d) Generalized Fourier transform along the group axis. Real input
    # is transformed in complex64; complex input keeps its own dtype.
    F = G.F.to(X.dtype if X.is_complex() else torch.complex64).to(device)
    Xh = torch.einsum("nfk,kj->nfj", X.to(F.dtype), F)
    power = Xh.abs() ** 2  # real-valued, (N, n_f, |G|)
    col_power = power.sum(dim=1).to(X.dtype)  # (N, |G|)
    # Per-row generalized Fourier power for the first K_rows non-invariant rows.
    row_power = power[:, eq_idx[:K_rows], :].to(X.dtype).reshape(N, K_rows * n_g)

    # (e) ★_G-SVD tube norms.
    tube_norms = _starg_svd_tube_norms(X, G, p, q, n_svd)

    # (f) Direct invariants: group-axis entry 0 of every row flagged invariant.
    # With an all-False mask this is an empty (N, 0) tensor.
    inv_feat = X[:, inv_mask, 0]

    # (g) Spectral statistics of each (n_f, |G|) matrix.
    stats = _spectral_stats(X)

    feat = torch.cat(
        [dc, ac, col_power, row_power, tube_norms, inv_feat, stats], dim=1
    )
    feat = torch.nan_to_num(feat, nan=0.0, posinf=0.0, neginf=0.0)

    if is_training:
        # Drop columns that are (numerically) constant over the batch.
        keep = feat.std(dim=0) >= 1e-8
        kept = feat[:, keep]
        mu = kept.mean(dim=0)
        sig = kept.std(dim=0)
        sig = torch.where(sig < 1e-10, torch.ones_like(sig), sig)
        norm = FeatureNormParams(
            p, q, n_svd, inv_mask, eq_idx, K_rows, keep, mu, sig
        )
    else:
        kept = feat[:, norm.keep.to(device)]
        mu = norm.mu.to(device)
        sig = norm.sig.to(device)

    feat = (kept - mu) / sig
    intercept = torch.ones(N, 1, device=device, dtype=feat.dtype)
    return torch.cat([intercept, feat], dim=1), norm


def _starg_svd_tube_norms(
    X: torch.Tensor, G, p: int, q: int, n_svd: int
) -> torch.Tensor:
    """Compute Frobenius norms of singular tubes for the (p, q, |G|) reshape.

    The first ``p * q`` rows of ``X`` (zero-padded if ``X`` has fewer) are
    folded into ``(N, p, q, |G|)``, transformed along the group axis with
    ``G.F`` in complex64, and the singular values of every ``(p, q)`` slice
    are taken. Returns an ``(N, n_svd)`` tensor sorted in descending order.
    """
    N, n_f, n_g = X.shape
    pad = max(0, p * q - n_f)
    if pad > 0:
        filler = torch.zeros(N, pad, n_g, device=X.device, dtype=X.dtype)
        X_padded = torch.cat([X, filler], dim=1)
    else:
        # p * q can be smaller than n_f; the surplus rows are dropped.
        X_padded = X[:, : p * q, :]
    Xt = X_padded.reshape(N, p, q, n_g)

    F = G.F.to(torch.complex64).to(X.device)
    Xh = torch.einsum("npqk,kj->npqj", Xt.to(F.dtype), F)

    # Per-frequency SVD on (p, q) matrices.
    Xh_perm = Xh.permute(0, 3, 1, 2)  # (N, n_g, p, q)
    s = torch.linalg.svdvals(Xh_perm)  # (N, n_g, min(p, q)), real-valued
    s = s[..., :n_svd]

    # Tube norm = ||s(k, k, :)||_F in the *original* domain.
    # By Parseval, ||s_orig||_F = ||s_hat||_F up to the 1/sqrt(|G|) factor,
    # so we use the per-singular-tube Fourier norm directly.
    # NOTE(review): this scaling presumes G.F is an unnormalized transform;
    # confirm against the GroupAlgebra implementation.
    tube_norms = s.norm(dim=1) / np.sqrt(n_g)  # (N, n_svd)
    tube_norms, _ = tube_norms.sort(dim=1, descending=True)
    return tube_norms.to(X.dtype)


def _spectral_stats(X: torch.Tensor) -> torch.Tensor:
    """Nuclear / spectral norm, condition number, and entropy of singular values.

    ``X`` has shape ``(N, n_f, |G|)``; each ``(n_f, |G|)`` slice is treated as
    a matrix. Returns an ``(N, 4)`` tensor in ``X``'s dtype with columns
    ``[nuclear, spectral, cond, entropy]``.
    """
    # Work in single precision, but stay complex for complex input so the
    # imaginary part contributes to the singular values instead of being
    # discarded by a cast to a real dtype.
    work_dtype = torch.complex64 if X.is_complex() else torch.float32
    s = torch.linalg.svdvals(X.to(work_dtype))  # (N, min(n_f, |G|)), descending
    nuclear = s.sum(dim=1)
    spectral = s[:, 0]
    # Clamp so the smallest singular value cannot produce a division by zero.
    s_clip = torch.clamp(s, min=1e-10)
    cond = s[:, 0] / s_clip[:, -1]
    # Shannon entropy of the normalized singular-value distribution.
    prob = s_clip / (s_clip.sum(dim=1, keepdim=True) + 1e-20)
    entropy = -(prob * (prob + 1e-20).log()).sum(dim=1)
    return torch.stack([nuclear, spectral, cond, entropy], dim=1).to(X.dtype)