"""Within-isomer R² audit.
Pooled R² on QM9 HOMO-LUMO gap is dominated by molecule-size variance:
gap depends strongly on number of atoms, formula, and conjugation length,
so any model that captures gross size/composition will score R² > 0.95.
The interesting test of a model's chemistry-aware predictive power is
*within-isomer* R²: how well does the method distinguish constitutional
isomers of the same molecular formula?
This script:
1. Loads test-set predictions for each method (from per-method JSONs
that contain `predictions` arrays, emitted by train_starg.py and
train_baseline_*.py when --save_predictions is set).
2. Groups molecules by molecular formula.
3. For groups of size >= MIN_GROUP_SIZE, computes within-group R².
4. Reports a method × seed table (for the selected --target) of pooled
R² and within-isomer R², the latter being the group-size-weighted mean
of the per-formula R² values across audited groups.
If the within-isomer R² collapses to a small number for every method,
that's the headline reframing of the pooled-R² horse race: the apparent
ENN advantage is largely a size-prediction effect.
If the within-isomer R² for an ENN remains substantial while ★_G
collapses, that's also informative: it means the ENN's bond-topology
input carries chemistry signal that the (14, |G|) molecule-level summary
does not, and the right reading is "input-information gap, not algebra
gap."
Usage:
python isomer_audit.py --results_dir results/ --qm9_dir ~/data/qm9/dsgdb9nsd \
--target gap --methods starg_ridge,starg_neural,mace,e3nn,schnet
"""
from __future__ import annotations

import argparse
import csv
import json
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX
# Smallest number of molecules sharing one molecular formula for which a
# within-group R² is computed; formulas with fewer isomers are skipped.
MIN_GROUP_SIZE: int = 5
def molecular_formula(Z) -> Tuple[Tuple[int, int], ...]:
    """Return a hashable canonical formula key for one molecule.

    The key is a tuple of ``(atomic_number, count)`` pairs sorted by atomic
    number, e.g. ``((1, 6), (6, 2), (8, 2))`` for ethanediol, C2H6O2.
    Constitutional isomers therefore map to the same key.

    Args:
        Z : iterable of atomic numbers; each element must be
            ``int()``-convertible (plain ints or a NumPy integer array).
    """
    return tuple(sorted(Counter(int(z) for z in Z).items()))
def per_group_r2(y_true: np.ndarray, y_pred: np.ndarray,
                 groups: List[int],
                 min_group_size: Optional[int] = None,
                 ) -> Tuple[Dict[int, float], float, int]:
    """Compute R² inside each group and a sample-weighted aggregate.

    Args:
        y_true, y_pred : per-sample targets and predictions. Both are
            flattened, so ``(N,)`` and ``(N, 1)`` arrays are treated alike
            (without this an ``(N, 1)`` prediction would broadcast against
            ``(N,)`` targets into an ``(N, N)`` residual matrix).
        groups : group label for each sample, same length as ``y_true``.
        min_group_size : groups with fewer samples are skipped; ``None``
            (default) uses the module-level ``MIN_GROUP_SIZE``.

    Returns:
        r2_by_group : dict {group label → R²} for every audited group
        weighted_mean_r2 : mean of the per-group R² weighted by group size;
            NaN when no group was audited (0.0 would wrongly read as
            "predicts the group mean")
        n_groups : number of audited groups, i.e. groups with size >=
            min_group_size whose targets are not all identical

    Raises:
        ValueError: if ``y_true``, ``y_pred`` and ``groups`` differ in length.
    """
    threshold = MIN_GROUP_SIZE if min_group_size is None else min_group_size
    yt_all = np.asarray(y_true, dtype=float).ravel()
    yp_all = np.asarray(y_pred, dtype=float).ravel()
    if not (yt_all.size == yp_all.size == len(groups)):
        raise ValueError(
            f"length mismatch: y_true={yt_all.size}, y_pred={yp_all.size}, "
            f"groups={len(groups)}"
        )

    # Collect sample indices per group label.
    bucket: Dict[int, List[int]] = defaultdict(list)
    for idx, g in enumerate(groups):
        bucket[g].append(idx)

    r2_by_group: Dict[int, float] = {}
    weighted_sum = 0.0
    total_weight = 0
    for g, idxs in bucket.items():
        if len(idxs) < threshold:
            continue
        yt = yt_all[idxs]
        yp = yp_all[idxs]
        ss_tot = float(np.sum((yt - yt.mean()) ** 2))
        if ss_tot < 1e-12:
            continue  # degenerate group, all targets identical: R² undefined
        ss_res = float(np.sum((yt - yp) ** 2))
        r2 = 1.0 - ss_res / ss_tot
        r2_by_group[g] = r2
        weighted_sum += r2 * len(idxs)
        total_weight += len(idxs)

    weighted_mean = weighted_sum / total_weight if total_weight else float("nan")
    return r2_by_group, weighted_mean, len(r2_by_group)
def load_predictions(method_dir: Path, target: str, seeds=(0, 1, 2)
                     ) -> Dict[int, Dict]:
    """Load per-seed result JSONs that contain a 'predictions' array.

    Reads ``<method_dir>/<target>/seed<k>.json`` for each ``k`` in *seeds*.
    Missing files are skipped silently. Files whose top-level JSON value is
    not an object with a ``predictions`` key are skipped with a message.

    Args:
        method_dir : per-method results directory (``Path`` or ``str``).
        target : target name; used as the sub-directory name.
        seeds : seeds to look for.

    Returns:
        dict {seed → parsed JSON object} for every seed that was loaded.
    """
    out: Dict[int, Dict] = {}
    for seed in seeds:
        path = Path(method_dir) / target / f"seed{seed}.json"
        if not path.exists():
            continue
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(path, encoding="utf-8") as fp:
            data = json.load(fp)
        # A non-object JSON value (list/string) must not pass a membership
        # test by accident; callers index data["predictions"].
        if not isinstance(data, dict) or "predictions" not in data:
            print(f" [skip] {path} has no 'predictions' field")
            continue
        out[seed] = data
    return out
def _assign_formula_ids(samples) -> Tuple[np.ndarray, int]:
    """Map every sample to a dense integer id of its molecular formula.

    Ids are assigned in order of first appearance. Returns
    ``(ids, n_formulas)`` where ``ids[i]`` is the formula id of
    ``samples[i]``.
    """
    formula_to_id: Dict[Tuple[Tuple[int, int], ...], int] = {}
    # setdefault evaluates len() before inserting, so a new formula gets the
    # next free id and a known one keeps its existing id.
    ids = [
        formula_to_id.setdefault(molecular_formula(s.Z), len(formula_to_id))
        for s in samples
    ]
    return np.asarray(ids, dtype=int), len(formula_to_id)


def _audit_seed(method: str, seed: int, data: Dict, y: np.ndarray,
                formula_ids: np.ndarray) -> Optional[Tuple]:
    """Compute one CSV row (pooled and within-isomer R²) for a method/seed.

    Returns ``None`` (after printing the reason) when the saved predictions
    cannot be aligned with the loaded molecules.
    """
    # Flatten so an (N, 1) prediction array does not broadcast against the
    # (N,) targets into an (N, N) residual matrix.
    preds = np.asarray(data["predictions"], dtype=float).ravel()
    # Use the test_idx the trainer recorded if available; some baselines
    # (e.g. PyG-SchNet) use a different split convention from qm9_split, so
    # the embedded test_idx is authoritative.
    if "test_idx" in data:
        test_idx = np.asarray(data["test_idx"], dtype=int).ravel()
    else:
        _, _, split_test = qm9_split(len(y), seed=seed)
        test_idx = np.asarray(split_test, dtype=int).ravel()

    tag = f"[{method}/seed{seed}]"
    if test_idx.size == 0:
        print(f"{tag} [skip] empty test set")
        return None
    if test_idx.min() < 0 or test_idx.max() >= len(y):
        print(f"{tag} [skip] test_idx out of range for {len(y)} loaded "
              f"molecules (was --max_molecules smaller than at training time?)")
        return None
    if preds.size != test_idx.size:
        print(f"{tag} [skip] {preds.size} predictions vs "
              f"{test_idx.size} test indices")
        return None

    y_test = y[test_idx]
    test_groups = formula_ids[test_idx]
    ss_res = float(np.sum((y_test - preds) ** 2))
    ss_tot = float(np.sum((y_test - y_test.mean()) ** 2))
    pooled = 1.0 - ss_res / max(ss_tot, 1e-12)
    r2_by_g, mean_r2, n_groups = per_group_r2(
        y_test, preds, test_groups.tolist()
    )
    # Test molecules belonging to a formula group that was actually audited.
    n_in_groups = int(np.isin(test_groups, list(r2_by_g)).sum())
    print(f"[{method}/seed{seed}] pooled R²={pooled:.4f} "
          f"within-isomer R²={mean_r2:.4f} "
          f"({n_groups} formulas, {n_in_groups} test molecules)")
    return (method, seed, len(test_idx), pooled,
            n_groups, n_in_groups, mean_r2)


def main():
    """CLI entry point: audit pooled vs within-isomer R² and write a CSV.

    For each requested method and seed, loads saved test-set predictions,
    computes pooled R² and the group-size-weighted within-formula R², prints
    a summary line, and writes one CSV row.

    Raises:
        ValueError: for an unknown ``--target``.
        NotImplementedError: for tensor/vector targets (scalar-only audit).
    """
    ap = argparse.ArgumentParser(description="Within-isomer R² audit on QM9.")
    ap.add_argument("--results_dir", required=True)
    ap.add_argument("--qm9_dir", required=True)
    ap.add_argument("--target", default="gap")
    ap.add_argument("--methods",
                    default="starg_ridge,starg_neural,mlp_standard,"
                            "mlp_invariant,mlp_augmented,schnet,e3nn,mace")
    ap.add_argument("--seeds", default="0,1,2",
                    help="comma-separated seeds whose seed<k>.json to audit")
    ap.add_argument("--max_molecules", type=int, default=None)
    ap.add_argument("--out_csv", default=None,
                    help="optional CSV path; default <results_dir>/isomer_audit.csv")
    args = ap.parse_args()

    target = args.target
    if target not in PROPERTY_INDEX and target not in ("mu_vector", "alpha_tensor"):
        raise ValueError(f"unknown target: {target}")
    if target not in PROPERTY_INDEX:
        # Tensor / vector targets, left for future extensions. Checked before
        # the dataset load so the user gets the error immediately.
        raise NotImplementedError(
            "isomer audit currently scalar-only; tensor targets require a "
            "per-component group R² which is straightforward to add later."
        )
    seeds = tuple(int(s) for s in args.seeds.split(",") if s.strip())
    # Drop empty entries (e.g. a trailing comma) so they don't resolve to
    # results_dir itself.
    methods = [m.strip() for m in args.methods.split(",") if m.strip()]
    results_dir = Path(args.results_dir)
    out_csv = Path(args.out_csv) if args.out_csv else results_dir / "isomer_audit.csv"

    # Load QM9 once to recover the molecular formulas. The per-seed split
    # is reproducible from qm9_split(seed) so we can map test-set indices
    # back to formulas.
    print(f"[load] QM9 from {args.qm9_dir}")
    ds = QM9Dataset(args.qm9_dir, max_molecules=args.max_molecules)
    samples = [ds[i] for i in range(len(ds))]
    formula_ids, n_formulas = _assign_formula_ids(samples)
    print(f"[load] {len(samples)} molecules, "
          f"{n_formulas} distinct formulas")
    y = np.array([s.properties[PROPERTY_INDEX[target]] for s in samples])

    rows = [("method", "seed", "n_test_total", "pooled_r2",
             "n_groups_audited", "n_test_in_groups",
             "within_isomer_r2_weighted_mean")]
    for method in methods:
        method_dir = results_dir / method
        if not method_dir.is_dir():
            print(f"[skip] {method}: results dir missing ({method_dir})")
            continue
        seed_results = load_predictions(method_dir, target, seeds=seeds)
        if not seed_results:
            print(f"[skip] {method}: no per-seed predictions found "
                  f"(re-run training with --save_predictions)")
            continue
        for seed, data in seed_results.items():
            row = _audit_seed(method, seed, data, y, formula_ids)
            if row is not None:
                rows.append(row)

    out_csv.parent.mkdir(parents=True, exist_ok=True)
    # csv.writer quotes fields that contain commas; "\n" keeps the original
    # line endings (the csv default is "\r\n").
    with open(out_csv, "w", newline="", encoding="utf-8") as fp:
        csv.writer(fp, lineterminator="\n").writerows(rows)
    print(f"\n[ok] wrote {out_csv}")


if __name__ == "__main__":
    main()