"""Per-irrep R² decomposition.
For each (target, irrep ρ) cell, train a ridge regressor on *only* the
features that project onto irrep ρ (the per-irrep Fourier power
features produced by the generalized Fourier transform) and report the
resulting test R². The output is a target × irrep table, plus a
``full_concat`` column that uses all irreps together:
| | A1 | A2 | E | T1 | T2 |
|-------|------|------|------|------|------|
| gap | 0.62 | 0.62 | 0.30 | 0.40 | 0.21 |
| mu_x | 0.01 | 0.01 | 0.01 | 0.04 | 0.00 |
| ...
This decomposition is uniquely available to the ★_G framework because
the per-irrep features are produced by the generalized Fourier transform
of the molecule tensor over the chosen group. ENNs can be probed for
similar information via attention or activation analysis, but they do
not yield a closed-form per-irrep R², because their readout is trained
end-to-end.
Usage:
python per_irrep_audit.py --qm9_dir ~/data/qm9/dsgdb9nsd \
--group octahedral --targets gap,alpha,mu,zpve --seed 0 \
--out_csv results/per_irrep_audit.csv
"""
from __future__ import annotations

import argparse
import csv
import json
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
from sklearn.linear_model import RidgeCV

from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX
from data.featurizers import octahedral_features, cyclic_angular_features
from starg_torch.algebra import GroupAlgebra
from starg_torch.octahedral import octahedral_rotations, octahedral_irreps
# Irrep labels of the octahedral group (n_g = 24), listed in the block
# order assembled by octahedral_irreps(). OCT_IRREP_DIMS is aligned with
# OCT_IRREP_NAMES index-for-index; 1² + 1² + 2² + 3² + 3² = 24, so the
# d² Fourier blocks exactly tile the 24 group-indexed columns.
OCT_IRREP_NAMES = "A1 A2 E T1 T2".split()
OCT_IRREP_DIMS = [
    1,  # A1
    1,  # A2
    2,  # E
    3,  # T1
    3,  # T2
]
def _per_irrep_power_features(X: np.ndarray, irrep_dims=OCT_IRREP_DIMS,
                              ) -> Dict[str, np.ndarray]:
    """Return per-irrep Fourier power features of the molecule tensor.

    ``X`` (shape ``(N, n_feat, |G|)``) is projected onto the generalized
    Fourier matrix ``F_G`` returned by ``octahedral_irreps()``. For each
    irrep ρ of dimension d, the d² Fourier-block columns belonging to ρ are
    sliced out and their squared magnitudes are summed. That sum is the
    squared Frobenius power of ρ's block for each (sample, feature row).

    The columns of ``F_G`` are assumed to be arranged by irrep block in the
    order of ``OCT_IRREP_NAMES``, each irrep contributing d² consecutive
    columns.

    Args:
        X: array of shape ``(N, n_feat, |G|)``.
        irrep_dims: dimension of each irrep, aligned index-for-index with
            ``OCT_IRREP_NAMES``.

    Returns:
        Dict mapping each irrep name to a float32 array of shape
        ``(N, n_feat)``.

    Raises:
        ValueError: if ``X`` is not 3-D, if ``irrep_dims`` does not have
            exactly one entry per irrep name, or if the d² blocks do not
            exactly cover the Fourier columns produced by ``F_G``.
    """
    X = np.asarray(X)
    if X.ndim != 3:
        raise ValueError(
            f"expected X of shape (N, n_feat, |G|), got shape {X.shape}")
    irrep_dims = list(irrep_dims)
    # zip() would silently drop irreps on a length mismatch; fail loudly.
    if len(irrep_dims) != len(OCT_IRREP_NAMES):
        raise ValueError(
            f"irrep_dims has {len(irrep_dims)} entries but there are "
            f"{len(OCT_IRREP_NAMES)} irreps {OCT_IRREP_NAMES}")

    F, _ = octahedral_irreps()
    F = np.asarray(F, dtype=np.complex128)
    # (N, n_feat, |G|) @ F -> (N, n_feat, n_cols) in Fourier coordinates.
    Xh = np.einsum("nfk,kj->nfj", X.astype(np.complex128), F)

    block_sizes = [d * d for d in irrep_dims]
    # The irrep blocks must tile the Fourier columns exactly; otherwise the
    # slices below would mix irreps or ignore columns.
    if sum(block_sizes) != Xh.shape[2]:
        raise ValueError(
            f"irrep blocks cover {sum(block_sizes)} columns but the Fourier "
            f"transform has {Xh.shape[2]}")

    out: Dict[str, np.ndarray] = {}
    cursor = 0
    for name, size in zip(OCT_IRREP_NAMES, block_sizes):
        block = Xh[:, :, cursor:cursor + size]
        # Squared Frobenius power per (sample, feature row); already real.
        out[name] = (np.abs(block) ** 2).sum(axis=2).astype(np.float32)
        cursor += size
    return out
def _stack_features(samples) -> np.ndarray:
    """Featurize every sample with ``octahedral_features`` and stack them.

    The per-sample arrays are stacked along a new leading axis, giving one
    entry per sample. ``main`` expects the result to be ``(N, 14, 24)``;
    the trailing shape is whatever ``octahedral_features`` returns.
    """
    per_sample = list(map(octahedral_features, samples))
    return np.stack(per_sample, axis=0)
def _build_target(samples, target: str) -> np.ndarray:
    """Collect the scalar QM9 property ``target`` from each sample.

    The property is read from ``sample.properties`` at the column given by
    ``PROPERTY_INDEX[target]``.

    Raises:
        NotImplementedError: if ``target`` is not a key of ``PROPERTY_INDEX``.
    """
    if target not in PROPERTY_INDEX:
        raise NotImplementedError(
            f"per-irrep audit currently only supports scalar QM9 targets; "
            f"got '{target}'."
        )
    column = PROPERTY_INDEX[target]
    return np.array([sample.properties[column] for sample in samples])
def _ridge_test_r2(X_tr: np.ndarray, y_tr: np.ndarray,
                   X_va: np.ndarray, y_va: np.ndarray,
                   X_te: np.ndarray, y_te: np.ndarray,
                   *, alphas=None) -> float:
    """Fit a standardized ridge regressor and return its R² on the test fold.

    Features are standardized with statistics from the training fold only.
    A ``RidgeCV`` model is then fit on the training and validation folds
    stacked together. RidgeCV picks the penalty from ``alphas`` using its
    default internal cross-validation. The model is scored on the test fold.

    Args:
        X_tr, y_tr: training features (2-D, one row per sample) and targets.
        X_va, y_va: validation fold; here it only serves as extra fit data.
        X_te, y_te: held-out test fold used for the reported R².
        alphas: candidate ridge penalties; defaults to
            ``np.logspace(-3, 3, 7)``.

    Returns:
        Test-set coefficient of determination ``1 - SS_res / SS_tot``.
        ``SS_tot`` is floored at 1e-12, so a constant ``y_te`` gives a very
        large negative value instead of a division error.
    """
    if alphas is None:
        alphas = np.logspace(-3, 3, 7)

    # Standardize on the training fold.
    mu = X_tr.mean(axis=0)
    sig = X_tr.std(axis=0)
    # A (near-)constant training column would otherwise be divided by ~1e-8.
    # That inflates any val/test deviation by up to 1e8, and since val rows
    # are part of the fit, they could dominate it. Such columns are left
    # centred but unscaled instead.
    sig = np.where(sig > 1e-8, sig, 1.0)
    X_tr = (X_tr - mu) / sig
    X_va = (X_va - mu) / sig
    X_te = (X_te - mu) / sig

    model = RidgeCV(alphas=alphas)
    model.fit(np.vstack([X_tr, X_va]), np.concatenate([y_tr, y_va]))
    y_pred = model.predict(X_te)

    ss_res = float(np.sum((y_te - y_pred) ** 2))
    ss_tot = float(np.sum((y_te - y_te.mean()) ** 2))
    return 1.0 - ss_res / max(ss_tot, 1e-12)
def main():
    """Run the per-irrep R² audit from the command line and write a CSV.

    Loads QM9 and computes octahedral features, then decomposes them into
    per-irrep power features. For each requested target it fits one ridge
    model per irrep, plus one on all irreps concatenated (``full_concat``).
    The resulting table of test R² values is written to ``--out_csv``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--qm9_dir", required=True)
    ap.add_argument("--group", default="octahedral",
                    choices=("octahedral",),
                    help="per-irrep audit currently wired for octahedral; "
                         "cyclic only has 1-d irreps so the table is trivial.")
    ap.add_argument("--targets", default="gap,alpha,mu,zpve")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--max_molecules", type=int, default=None)
    ap.add_argument("--out_csv", default="results/per_irrep_audit.csv")
    args = ap.parse_args()

    # Parse targets before the slow load/featurization. Surrounding
    # whitespace ("gap, alpha") and empty entries (trailing commas) are
    # ignored rather than treated as unknown targets.
    targets = [t.strip() for t in args.targets.split(",") if t.strip()]
    if not targets:
        ap.error("--targets must name at least one target")

    print(f"[load] QM9 from {args.qm9_dir}")
    ds = QM9Dataset(args.qm9_dir, max_molecules=args.max_molecules)
    samples = [ds[i] for i in range(len(ds))]
    print(f"[load] {len(samples)} molecules")

    print("[feat] computing octahedral features (n_g=24)")
    X = _stack_features(samples)  # expected (N, 14, 24)
    print(f"[feat] X shape: {X.shape}")
    print("[feat] decomposing per-irrep power features")
    per_irrep = _per_irrep_power_features(X)  # dict name → (N, n_feat)
    # All irreps concatenated serve as the upper bound for context. This
    # does not depend on the target, so it is built once, not per target.
    full = np.concatenate([per_irrep[n] for n in OCT_IRREP_NAMES], axis=1)

    train_idx, val_idx, test_idx = qm9_split(len(samples), seed=args.seed)

    def split_r2(feats: np.ndarray, y: np.ndarray) -> float:
        # Test R² of a ridge fit on one feature set, using the shared split.
        return _ridge_test_r2(feats[train_idx], y[train_idx],
                              feats[val_idx], y[val_idx],
                              feats[test_idx], y[test_idx])

    rows = [["target", *OCT_IRREP_NAMES, "full_concat"]]
    for target in targets:
        y = _build_target(samples, target)
        cells: List[float] = []
        for name in OCT_IRREP_NAMES:
            r2 = split_r2(per_irrep[name], y)
            cells.append(r2)
            print(f" [{target} / {name}] R² = {r2:.4f}")
        r2_full = split_r2(full, y)
        cells.append(r2_full)
        print(f" [{target} / full_concat] R² = {r2_full:.4f}")
        rows.append([target, *(f"{c:.4f}" for c in cells)])

    out = Path(args.out_csv)
    out.parent.mkdir(parents=True, exist_ok=True)
    # newline="" is required by the csv module. lineterminator="\n" keeps
    # the plain "\n" row separator instead of csv's default "\r\n".
    with open(out, "w", newline="", encoding="utf-8") as fp:
        csv.writer(fp, lineterminator="\n").writerows(rows)
    print(f"\n[ok] wrote {out}")
# Run the audit only when executed as a script, not when imported.
if __name__ == "__main__":
    main()