"""Within-isomer R² audit. Pooled R² on QM9 HOMO-LUMO gap is dominated by molecule-size variance: gap depends strongly on number of atoms, formula, and conjugation length, so any model that captures gross size/composition will score R² > 0.95. The interesting test of a model's chemistry-aware predictive power is *within-isomer* R², how well does the method distinguish constitutional isomers of the same molecular formula? This script: 1. Loads test-set predictions for each method (from per-method JSONs that contain `predictions` arrays, emitted by train_starg.py and train_baseline_*.py when --save_predictions is set). 2. Groups molecules by molecular formula. 3. For groups of size >= MIN_GROUP_SIZE, computes within-group R². 4. Reports method × target table of (pooled R², within-isomer R²) plus the formula-weighted mean R² across groups. If the within-isomer R² collapses to a small number for every method, that's the headline reframing of the pooled-R² horse race: the apparent ENN advantage is largely a size-prediction effect. If the within-isomer R² for an ENN remains substantial while ★_G collapses, that's also informative, it means the ENN's bond-topology input carries chemistry signal that the (14, |G|) molecule-level summary does not, and the right reading is "input-information gap, not algebra gap." Usage: python isomer_audit.py --results_dir results/ --qm9_dir ~/data/qm9/dsgdb9nsd \ --target gap --methods starg_ridge,starg_neural,mace,e3nn,schnet """ from __future__ import annotations import argparse import json from collections import Counter, defaultdict from pathlib import Path from typing import Dict, List, Tuple import numpy as np from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX MIN_GROUP_SIZE = 5 # only compute within-group R² for formulas with >=5 isomers def molecular_formula(Z) -> Tuple[Tuple[int, int], ...]: """Return a hashable canonical formula key: tuple of (Z, count) pairs sorted by Z. e.g. (1,8), (6,2), (8,2) for ethanediol C2H8O2.""" counts = Counter(int(z) for z in Z) return tuple(sorted(counts.items())) def per_group_r2(y_true: np.ndarray, y_pred: np.ndarray, groups: List[int]) -> Tuple[Dict[int, float], float, int]: """Compute R² inside each group and a sample-weighted aggregate. Returns: r2_by_group : dict {group_index → R²} weighted_mean_r2 : float, weighted by group size n_groups : count of groups with size >= MIN_GROUP_SIZE """ r2_by_group: Dict[int, float] = {} bucket: Dict[int, List[int]] = defaultdict(list) for idx, g in enumerate(groups): bucket[g].append(idx) weighted_sum = 0.0 total_weight = 0 for g, idxs in bucket.items(): if len(idxs) < MIN_GROUP_SIZE: continue yt = y_true[idxs] yp = y_pred[idxs] ss_tot = float(np.sum((yt - yt.mean()) ** 2)) if ss_tot < 1e-12: continue # degenerate group, all targets identical ss_res = float(np.sum((yt - yp) ** 2)) r2 = 1.0 - ss_res / ss_tot r2_by_group[g] = r2 weighted_sum += r2 * len(idxs) total_weight += len(idxs) weighted_mean = weighted_sum / max(total_weight, 1) return r2_by_group, weighted_mean, len(r2_by_group) def load_predictions(method_dir: Path, target: str, seeds=(0, 1, 2) ) -> Dict[int, Dict]: """Load per-seed result JSONs that contain a 'predictions' array.""" out = {} for seed in seeds: f = method_dir / target / f"seed{seed}.json" if not f.exists(): continue with open(f) as fp: data = json.load(fp) if "predictions" not in data: print(f" [skip] {f} has no 'predictions' field") continue out[seed] = data return out def main(): ap = argparse.ArgumentParser() ap.add_argument("--results_dir", required=True) ap.add_argument("--qm9_dir", required=True) ap.add_argument("--target", default="gap") ap.add_argument("--methods", default="starg_ridge,starg_neural,mlp_standard," "mlp_invariant,mlp_augmented,schnet,e3nn,mace") ap.add_argument("--max_molecules", type=int, default=None) ap.add_argument("--out_csv", default=None, help="optional CSV path; default /isomer_audit.csv") args = ap.parse_args() target = args.target if target not in PROPERTY_INDEX and target not in ("mu_vector", "alpha_tensor"): raise ValueError(f"unknown target: {target}") results_dir = Path(args.results_dir) out_csv = Path(args.out_csv) if args.out_csv else results_dir / "isomer_audit.csv" # Load QM9 once to recover the molecular formulas. The per-seed split # is reproducible from qm9_split(seed) so we can map test-set indices # back to formulas. print(f"[load] QM9 from {args.qm9_dir}") ds = QM9Dataset(args.qm9_dir, max_molecules=args.max_molecules) samples = [ds[i] for i in range(len(ds))] formulas = [molecular_formula(s.Z) for s in samples] formula_to_id = {} formula_ids = [] for f in formulas: if f not in formula_to_id: formula_to_id[f] = len(formula_to_id) formula_ids.append(formula_to_id[f]) formula_ids = np.array(formula_ids, dtype=int) print(f"[load] {len(samples)} molecules, " f"{len(formula_to_id)} distinct formulas") if target in PROPERTY_INDEX: y = np.array([s.properties[PROPERTY_INDEX[target]] for s in samples]) else: # Tensor / vector targets, left for future extensions raise NotImplementedError( "isomer audit currently scalar-only; tensor targets require a " "per-component group R² which is straightforward to add later." ) methods = args.methods.split(",") rows = [("method", "seed", "n_test_total", "pooled_r2", "n_groups_audited", "n_test_in_groups", "within_isomer_r2_weighted_mean")] for method in methods: method_dir = results_dir / method if not method_dir.is_dir(): print(f"[skip] {method}: results dir missing ({method_dir})") continue seed_results = load_predictions(method_dir, target) if not seed_results: print(f"[skip] {method}: no per-seed predictions found " f"(re-run training with --save_predictions)") continue for seed, data in seed_results.items(): preds = np.asarray(data["predictions"]) # Use the test_idx the trainer recorded if available; some # baselines (e.g. PyG-SchNet) use a different split convention # from qm9_split, so the embedded test_idx is authoritative. if "test_idx" in data: test_idx = np.asarray(data["test_idx"], dtype=int) else: _, _, test_idx = qm9_split(len(samples), seed=seed) y_test = y[test_idx] test_groups = formula_ids[test_idx] ss_res = float(np.sum((y_test - preds) ** 2)) ss_tot = float(np.sum((y_test - y_test.mean()) ** 2)) pooled = 1.0 - ss_res / max(ss_tot, 1e-12) r2_by_g, mean_r2, n_groups = per_group_r2( y_test, preds, list(test_groups) ) n_in_groups = sum( len(np.where(test_groups == g)[0]) for g in r2_by_g ) rows.append((method, seed, len(test_idx), pooled, n_groups, n_in_groups, mean_r2)) print(f"[{method}/seed{seed}] pooled R²={pooled:.4f} " f"within-isomer R²={mean_r2:.4f} " f"({n_groups} formulas, {n_in_groups} test molecules)") with open(out_csv, "w") as fp: for row in rows: fp.write(",".join(str(x) for x in row) + "\n") print(f"\n[ok] wrote {out_csv}") if __name__ == "__main__": main()