"""Torch port of the ★_G invariant feature extractor (extractStarGFeatures).
Vectorized implementation of the 7 invariant feature blocks (DC, AC, per-
frequency power, per-row generalized Fourier power, ★_G-SVD tube norms,
direct invariants, spectral statistics) used in every linear and Neural-★_G
experiment in the paper. Designed to run on a single GPU for batches up to
the full QM9 dataset (134k molecules).
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, TYPE_CHECKING
import numpy as np
import torch
if TYPE_CHECKING:
from .algebra import GroupAlgebra
@dataclass
class FeatureNormParams:
    """Layout choices and normalization statistics fitted by ``extract_starg_features``.

    An instance is produced when ``extract_starg_features`` is called with
    ``norm=None`` and can be passed back in on later calls so that the same
    reshape, row selection, column selection and standardization are reused.
    """

    # Row/column counts of the (p, q) reshape chosen by ``_best_reshape(n_f)``.
    p: int
    q: int
    # Number of singular tubes kept; set to ``min(p, q)`` when fitted.
    n_svd: int
    # Boolean mask over the n_f rows: True where the variance of sample 0
    # along the group axis falls below the relative threshold.
    inv_mask: torch.Tensor
    # Indices of the rows where ``inv_mask`` is False.
    eq_idx: torch.Tensor
    # How many leading entries of ``eq_idx`` contribute per-row power features.
    K_rows: int
    # Boolean mask over raw feature columns whose batch std is >= 1e-8.
    keep: torch.Tensor
    # Per-column mean of the kept feature columns.
    mu: torch.Tensor
    # Per-column std of the kept feature columns (values < 1e-10 replaced by 1).
    sig: torch.Tensor
def _best_reshape(n_f: int) -> tuple[int, int]:
best_min, p, q = 0, n_f, 1
for pp in range(2, int(np.sqrt(n_f * 2)) + 1):
qq = n_f // pp
if qq < 1:
continue
if min(pp, qq) > best_min:
best_min, p, q = min(pp, qq), pp, qq
return p, q
def extract_starg_features(
    X: torch.Tensor,
    G: "GroupAlgebra",
    norm: Optional[FeatureNormParams] = None,
    K_rows_default: int = 14,
) -> tuple[torch.Tensor, FeatureNormParams]:
    """Build the standardized invariant feature matrix for a batch.

    Args:
        X: Tensor of shape (N, n_f, |G|); the last axis is the group axis.
        G: Object exposing ``n`` (expected size of the group axis) and ``F``
            (the matrix applied along the group axis as the generalized
            Fourier transform).
        norm: Parameters from an earlier call. When ``None`` the reshape,
            row selection, column selection and mean/std are fitted on ``X``.
        K_rows_default: Upper bound on how many non-invariant rows
            contribute per-row power features when fitting.

    Returns:
        ``(features, norm)`` where ``features`` is a column of ones followed
        by the kept, standardized feature columns, and ``norm`` is the
        parameter set that was used (freshly fitted or the one passed in).

    Raises:
        ValueError: If the last axis of ``X`` does not match ``G.n``.
    """
    N, n_f, n_g = X.shape
    if n_g != G.n:
        raise ValueError(f"X group axis {n_g} != |G|={G.n}")
    device = X.device
    fit_params = norm is None

    if fit_params:
        p, q = _best_reshape(n_f)
        n_svd = min(p, q)
        # Rows that are constant along the group axis are flagged using
        # sample 0 only, as a proxy for the whole batch.
        var_sample0 = X[0].var(dim=1)  # (n_f,)
        inv_mask = var_sample0 < 1e-8 * (var_sample0.max() + 1e-20)
        eq_idx = torch.nonzero(~inv_mask).flatten()
        K_rows = min(K_rows_default, eq_idx.numel())
    else:
        p, q, n_svd = norm.p, norm.q, norm.n_svd
        inv_mask, eq_idx, K_rows = norm.inv_mask, norm.eq_idx, norm.K_rows

    # (a) mean and (b) standard deviation along the group axis, each (N, n_f).
    group_mean = X.mean(dim=2)
    group_std = X.std(dim=2)

    # (c) + (d) transform along the group axis; real input is promoted to
    # complex64, complex input keeps its own dtype.
    transform_dtype = X.dtype if X.is_complex() else torch.complex64
    F = G.F.to(transform_dtype).to(device)
    Xh = torch.einsum("nfk,kj->nfj", X.to(F.dtype), F)
    power = Xh.abs() ** 2  # (N, n_f, n_g)
    # (c) power per frequency, summed over rows.
    freq_power = power.sum(dim=1).real.to(X.dtype)
    # (d) power per frequency for the first K_rows non-invariant rows.
    selected_rows = eq_idx[:K_rows]
    per_row_power = power[:, selected_rows, :].real.to(X.dtype)
    per_row_power = per_row_power.reshape(N, K_rows * n_g)

    # (e) singular-tube norms of the (p, q, |G|) reshape.
    tubes = _starg_svd_tube_norms(X, G, p, q, n_svd)

    # (f) invariant rows are copied from group index 0.
    if inv_mask.any():
        invariant_rows = X[:, inv_mask, 0]
    else:
        invariant_rows = torch.zeros(N, 0, device=device, dtype=X.dtype)

    # (g) singular-value statistics of each sample's (n_f, |G|) matrix.
    sv_stats = _spectral_stats(X)

    blocks = [
        group_mean,
        group_std,
        freq_power,
        per_row_power,
        tubes,
        invariant_rows,
        sv_stats,
    ]
    feat = torch.cat(blocks, dim=1)
    # Non-finite entries (e.g. std over a single element) are zeroed.
    feat = torch.nan_to_num(feat, nan=0.0, posinf=0.0, neginf=0.0)

    if fit_params:
        # Drop (near-)constant columns, then record mean/std of the rest.
        column_std = feat.std(dim=0)
        keep = column_std >= 1e-8
        kept_cols = feat[:, keep]
        mu = kept_cols.mean(dim=0)
        sig = kept_cols.std(dim=0)
        sig = torch.where(sig < 1e-10, torch.ones_like(sig), sig)
        norm = FeatureNormParams(
            p=p,
            q=q,
            n_svd=n_svd,
            inv_mask=inv_mask,
            eq_idx=eq_idx,
            K_rows=K_rows,
            keep=keep,
            mu=mu,
            sig=sig,
        )

    feat = (feat[:, norm.keep] - norm.mu) / norm.sig
    ones_col = torch.ones(N, 1, device=device, dtype=feat.dtype)
    return torch.cat([ones_col, feat], dim=1), norm
def _starg_svd_tube_norms(X: torch.Tensor, G, p: int, q: int, n_svd: int) -> torch.Tensor:
"""Compute Frobenius norms of singular tubes for the (p, q, |G|) reshape."""
N, n_f, n_g = X.shape
pad = max(0, p * q - n_f)
if pad > 0:
X_padded = torch.cat([X, torch.zeros(N, pad, n_g, device=X.device, dtype=X.dtype)], dim=1)
else:
X_padded = X[:, : p * q, :]
Xt = X_padded.reshape(N, p, q, n_g)
F = G.F.to(torch.complex64).to(X.device)
Xh = torch.einsum("npqk,kj->npqj", Xt.to(F.dtype), F)
# Per-frequency SVD on (p, q) matrices
Xh_perm = Xh.permute(0, 3, 1, 2) # (N, n_g, p, q)
s = torch.linalg.svdvals(Xh_perm) # (N, n_g, min(p, q))
s = s[..., :n_svd]
# Tube norm = ||s(k, k, :)||_F in the *original* domain.
# By Parseval, ||s_orig||_F = ||s_hat||_F up to the 1/√|G| factor,
# so we use the per-singular-tube Fourier norm directly.
tube_norms = s.norm(dim=1) / np.sqrt(n_g) # (N, n_svd)
tube_norms, _ = tube_norms.sort(dim=1, descending=True)
return tube_norms.real.to(X.dtype)
def _spectral_stats(X: torch.Tensor) -> torch.Tensor:
"""Nuclear / spectral norm, condition number, and entropy of singular values."""
N, n_f, n_g = X.shape
Xm = X.reshape(N, n_f, n_g).to(torch.float32)
s = torch.linalg.svdvals(Xm) # (N, min)
nuclear = s.sum(dim=1)
spectral = s[:, 0]
s_clip = torch.clamp(s, min=1e-10)
cond = s[:, 0] / s_clip[:, -1]
p = s_clip / (s_clip.sum(dim=1, keepdim=True) + 1e-20)
entropy = -(p * (p + 1e-20).log()).sum(dim=1)
return torch.stack([nuclear, spectral, cond, entropy], dim=1).to(X.dtype)