"""Per-irrep R² decomposition. For each (target, irrep ρ) cell, train a ridge regressor on *only* the features that project onto irrep ρ, the per-irrep Fourier power features produced by the generalized Fourier transform, and report the resulting test R². The output is a target × irrep table: | | A1 | A2 | E | T1 | T2 | |-------|------|------|------|------|------| | gap | 0.62 | 0.62 | 0.30 | 0.40 | 0.21 | | mu_x | 0.01 | 0.01 | 0.01 | 0.04 | 0.00 | | ... This decomposition is uniquely available to the ★_G framework because the per-irrep features are produced by the generalized Fourier transform of the molecule tensor over the chosen group. ENNs can be probed for similar information via attention or activation analysis, but they do not produce a closed-form per-irrep R², the readout is end-to-end. Usage: python per_irrep_audit.py --qm9_dir ~/data/qm9/dsgdb9nsd \ --group octahedral --targets gap,alpha,mu,zpve --seed 0 \ --out_csv results/per_irrep_audit.csv """ from __future__ import annotations import argparse import json from pathlib import Path from typing import Dict, List import numpy as np import torch from sklearn.linear_model import RidgeCV from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX from data.featurizers import octahedral_features, cyclic_angular_features from starg_torch.algebra import GroupAlgebra from starg_torch.octahedral import octahedral_rotations, octahedral_irreps # Octahedral irreps in the order assembled by octahedral_irreps() OCT_IRREP_NAMES = ["A1", "A2", "E", "T1", "T2"] OCT_IRREP_DIMS = [1, 1, 2, 3, 3] def _per_irrep_power_features(X: np.ndarray, irrep_dims=OCT_IRREP_DIMS, ) -> Dict[str, np.ndarray]: """Given the (N, n_feat, |G|) molecule tensor, return per-irrep power features: for each irrep ρ of dim d, slice the d² Fourier-block columns out of the F_G transform and report the squared Frobenius power per feature row, yielding shape (N, n_feat) per irrep. The columns of the generalized Fourier matrix F_G are arranged by irrep block per `octahedral_irreps()`; each irrep ρ contributes d² consecutive columns. """ from starg_torch.octahedral import octahedral_irreps F, _ = octahedral_irreps() F = np.asarray(F, dtype=np.complex128) # Project: shape (N, n_feat, |G|) @ F gives (N, n_feat, |G|) in Fourier coords Xh = np.einsum("nfk,kj->nfj", X.astype(np.complex128), F) out = {} cursor = 0 for name, d in zip(OCT_IRREP_NAMES, irrep_dims): block_size = d * d block = Xh[:, :, cursor : cursor + block_size] # Squared Frobenius power per (sample, feature row) power = (np.abs(block) ** 2).sum(axis=2) out[name] = power.real.astype(np.float32) cursor += block_size return out def _stack_features(samples) -> np.ndarray: feats = [octahedral_features(s) for s in samples] return np.stack(feats, axis=0) def _build_target(samples, target: str) -> np.ndarray: if target in PROPERTY_INDEX: return np.array([s.properties[PROPERTY_INDEX[target]] for s in samples]) raise NotImplementedError( f"per-irrep audit currently only supports scalar QM9 targets; " f"got '{target}'." ) def _ridge_test_r2(X_tr: np.ndarray, y_tr: np.ndarray, X_va: np.ndarray, y_va: np.ndarray, X_te: np.ndarray, y_te: np.ndarray) -> float: # Standardize on training fold mu = X_tr.mean(axis=0) sig = X_tr.std(axis=0) + 1e-8 X_tr = (X_tr - mu) / sig X_va = (X_va - mu) / sig X_te = (X_te - mu) / sig model = RidgeCV(alphas=np.logspace(-3, 3, 7)) model.fit(np.vstack([X_tr, X_va]), np.concatenate([y_tr, y_va])) y_pred = model.predict(X_te) ss_res = float(np.sum((y_te - y_pred) ** 2)) ss_tot = float(np.sum((y_te - y_te.mean()) ** 2)) return 1.0 - ss_res / max(ss_tot, 1e-12) def main(): ap = argparse.ArgumentParser() ap.add_argument("--qm9_dir", required=True) ap.add_argument("--group", default="octahedral", choices=("octahedral",), help="per-irrep audit currently wired for octahedral; " "cyclic only has 1-d irreps so the table is trivial.") ap.add_argument("--targets", default="gap,alpha,mu,zpve") ap.add_argument("--seed", type=int, default=0) ap.add_argument("--max_molecules", type=int, default=None) ap.add_argument("--out_csv", default="results/per_irrep_audit.csv") args = ap.parse_args() print(f"[load] QM9 from {args.qm9_dir}") ds = QM9Dataset(args.qm9_dir, max_molecules=args.max_molecules) samples = [ds[i] for i in range(len(ds))] print(f"[load] {len(samples)} molecules") print(f"[feat] computing octahedral features (n_g=24)") X = _stack_features(samples) # (N, 14, 24) print(f"[feat] X shape: {X.shape}") print("[feat] decomposing per-irrep power features") per_irrep = _per_irrep_power_features(X) # dict name → (N, 14) train_idx, val_idx, test_idx = qm9_split(len(samples), seed=args.seed) targets = args.targets.split(",") rows = [("target",) + tuple(OCT_IRREP_NAMES) + ("full_concat",)] for target in targets: y = _build_target(samples, target) y_tr = y[train_idx]; y_va = y[val_idx]; y_te = y[test_idx] cells: List[float] = [] for name in OCT_IRREP_NAMES: P = per_irrep[name] r2 = _ridge_test_r2(P[train_idx], y_tr, P[val_idx], y_va, P[test_idx], y_te) cells.append(r2) print(f" [{target} / {name}] R² = {r2:.4f}") # All-irreps concatenated as the upper bound for context full = np.concatenate([per_irrep[n] for n in OCT_IRREP_NAMES], axis=1) r2_full = _ridge_test_r2(full[train_idx], y_tr, full[val_idx], y_va, full[test_idx], y_te) cells.append(r2_full) print(f" [{target} / full_concat] R² = {r2_full:.4f}") rows.append((target,) + tuple(f"{c:.4f}" for c in cells)) out = Path(args.out_csv) out.parent.mkdir(parents=True, exist_ok=True) with open(out, "w") as fp: for row in rows: fp.write(",".join(str(x) for x in row) + "\n") print(f"\n[ok] wrote {out}") if __name__ == "__main__": main()