"""Pareto-frontier figure: R² vs parameter count on QM9 HOMO-LUMO gap.
Plots R² vs n_params (logarithmic parameter axis, linear R² axis) for every method, with two panels:
(a) pooled R² (the "advertised" QM9 scoreboard)
(b) within-isomer R² (the controlled-for-size comparison)
The visual story: ★_G ridge sits at the leftmost point of the parameter
axis, MACE at the rightmost, and the within-isomer panel shows that the
apparent ENN advantage in pooled R² narrows substantially when
size-prediction signal is removed.
Inputs:
- results/summary.csv produced by eval_collect.py (pooled R²)
- results/isomer_audit.csv produced by isomer_audit.py (within-isomer R²; optional — panel (b) is omitted if the file does not exist)
Outputs (written to --out_dir, default results/):
- pareto_<target>.pdf
- pareto_<target>.png
Usage:
python pareto_figure.py --target gap --out_dir results/
"""
from __future__ import annotations
import argparse
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
# Method key -> (display name, marker colour).
# Dict insertion order is the order in which methods are drawn in each panel;
# the display name is used for the per-point text annotations.
METHOD_STYLE = {
    key: (display, colour)
    for key, display, colour in (
        ("starg_ridge", "★_G ridge", "#1f77b4"),
        ("starg_neural", "★_G neural", "#2ca02c"),
        ("mlp_standard", "MLP standard", "#7f7f7f"),
        ("mlp_invariant", "MLP invariant", "#bcbd22"),
        ("mlp_augmented", "MLP augmented", "#d62728"),
        ("schnet", "SchNet", "#ff7f0e"),
        ("e3nn", "e3nn (SE3)", "#9467bd"),
        ("mace", "MACE", "#8c564b"),
    )
}
def _load_isomer_table(path: Path, target: str) -> pd.DataFrame:
    """Reduce the isomer-audit CSV to one row per method.

    Returns a frame with exactly three columns so that merging it onto the
    summary table cannot collide with summary columns:
      - ``method``
      - ``isomer_r2``: mean of ``within_isomer_r2_weighted_mean`` over all
        rows for that method (i.e. over every seed present)
      - ``isomer_r2_std``: sample std of the same values (NaN when a method
        has a single row, e.g. only one seed or no ``seed`` column)
    """
    isomer = pd.read_csv(path)
    # If the audit covers several targets, keep only the one being plotted so
    # other targets are not averaged in.
    if "target" in isomer.columns:
        isomer = isomer[isomer["target"] == target]
    # Average over ALL seeds. Filtering to a single seed first would make the
    # cross-seed aggregation a no-op.
    return (
        isomer.groupby("method")["within_isomer_r2_weighted_mean"]
        .agg(isomer_r2="mean", isomer_r2_std="std")
        .reset_index()
    )


def _plot_pooled_panel(ax: plt.Axes, summary: pd.DataFrame, target: str) -> None:
    """Draw panel (a): pooled ``r2_mean`` ± ``r2_std`` against ``n_params``."""
    for method, (label, color) in METHOD_STYLE.items():
        row = summary[summary["method"] == method]
        if row.empty:
            continue
        x = float(row["n_params"].iloc[0])
        y = float(row["r2_mean"].iloc[0])
        std = row["r2_std"].iloc[0]
        yerr = 0 if pd.isna(std) else float(std)
        ax.errorbar(x, y, yerr=yerr, fmt="o", markersize=10,
                    color=color, label=label, capsize=4, elinewidth=1.2)
        ax.annotate(label, (x, y),
                    xytext=(8, 4), textcoords="offset points",
                    fontsize=9, color=color)
    ax.set_xscale("log")
    ax.set_xlabel("Trainable parameters (log)", fontsize=11)
    ax.set_ylabel(f"Pooled test R² ({target})", fontsize=11)
    ax.set_title("(a) Pooled R²: advertised QM9 score", fontsize=12)
    ax.grid(True, which="both", alpha=0.25, linestyle=":")
    ax.set_ylim(-0.05, 1.02)


def _plot_isomer_panel(ax: plt.Axes, merged: pd.DataFrame, target: str) -> None:
    """Draw panel (b): within-isomer R² (± seed std if available) vs ``n_params``."""
    for method, (label, color) in METHOD_STYLE.items():
        row = merged[merged["method"] == method]
        if row.empty or pd.isna(row["isomer_r2"].iloc[0]):
            continue
        x = float(row["n_params"].iloc[0])
        y = float(row["isomer_r2"].iloc[0])
        std = row["isomer_r2_std"].iloc[0]
        # Error bars only when there was more than one row to take a std over;
        # drawn under the marker so the scatter point stays on top.
        if not pd.isna(std):
            ax.errorbar(x, y, yerr=float(std), fmt="none", ecolor=color,
                        capsize=4, elinewidth=1.2, zorder=2)
        ax.scatter(x, y, s=120, color=color, label=label,
                   edgecolor="black", linewidth=0.6, zorder=3)
        ax.annotate(label, (x, y),
                    xytext=(8, 4), textcoords="offset points",
                    fontsize=9, color=color)
    ax.set_xscale("log")
    ax.set_xlabel("Trainable parameters (log)", fontsize=11)
    ax.set_ylabel(f"Within-isomer test R² ({target})", fontsize=11)
    ax.set_title("(b) Within-isomer R²: size-prediction signal removed",
                 fontsize=12)
    ax.grid(True, which="both", alpha=0.25, linestyle=":")


def main():
    """CLI entry point: build the R²-vs-parameter-count figure and save it.

    Reads ``--summary_csv`` (filtered to ``--target``) for panel (a) and, if
    ``--isomer_csv`` exists, the isomer audit for panel (b). Writes
    ``pareto_<target>.pdf`` and ``pareto_<target>.png`` into ``--out_dir``.

    Raises:
        SystemExit: if the summary CSV has no rows for ``--target``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--target", default="gap")
    ap.add_argument("--out_dir", default="results/")
    ap.add_argument("--summary_csv", default="results/summary.csv")
    ap.add_argument("--isomer_csv", default="results/isomer_audit.csv")
    args = ap.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    summary = pd.read_csv(args.summary_csv)
    summary = summary[summary["target"] == args.target]
    if summary.empty:
        # Fail loudly instead of writing an empty figure (e.g. --target typo).
        raise SystemExit(
            f"[error] no rows with target={args.target!r} in {args.summary_csv}"
        )

    isomer_path = Path(args.isomer_csv)
    have_isomer = isomer_path.exists()
    # Load the optional audit before creating the figure so a bad CSV does
    # not leave an open figure behind.
    isomer_agg = _load_isomer_table(isomer_path, args.target) if have_isomer else None

    # squeeze=False always yields a 2-D array, so the single-panel case needs
    # no special handling.
    fig, axes = plt.subplots(1, 2 if have_isomer else 1,
                             figsize=(12 if have_isomer else 7, 5),
                             sharey=False, squeeze=False)
    axes = axes[0]

    _plot_pooled_panel(axes[0], summary, args.target)
    if isomer_agg is not None:
        merged = summary.merge(isomer_agg, on="method", how="left")
        _plot_isomer_panel(axes[1], merged, args.target)

    fig.suptitle(
        f"Parameter-efficiency vs predictive power on QM9 {args.target}",
        fontsize=13, y=1.02,
    )
    fig.tight_layout()
    pdf_path = out_dir / f"pareto_{args.target}.pdf"
    png_path = out_dir / f"pareto_{args.target}.png"
    fig.savefig(pdf_path, bbox_inches="tight")
    fig.savefig(png_path, bbox_inches="tight", dpi=160)
    plt.close(fig)
    print(f"[ok] wrote {pdf_path} and {png_path}")
if __name__ == "__main__":
    # Only run the CLI when executed as a script, never on import.
    main()