"""Pareto-frontier figure: R² vs parameter count on QM9 HOMO-LUMO gap. Plots R² vs n_params on log-log axes for every method, with two panels: (a) pooled R² (the "advertised" QM9 scoreboard) (b) within-isomer R² (the controlled-for-size comparison) The visual story: ★_G ridge sits at the leftmost point of the parameter axis, MACE at the rightmost, and the within-isomer panel shows that the apparent ENN advantage in pooled R² narrows substantially when size-prediction signal is removed. Inputs: - results/summary.csv produced by eval_collect.py (pooled R²) - results/isomer_audit.csv produced by isomer_audit.py (within-isomer R²) Output: - results/pareto_figure.pdf - results/pareto_figure.png Usage: python pareto_figure.py --target gap --out_dir results/ """ from __future__ import annotations import argparse from pathlib import Path import pandas as pd import matplotlib.pyplot as plt import matplotlib.ticker as mticker import numpy as np # Method → (display name, color), order matters for legend METHOD_STYLE = { "starg_ridge": ("★_G ridge", "#1f77b4"), "starg_neural": ("★_G neural", "#2ca02c"), "mlp_standard": ("MLP standard", "#7f7f7f"), "mlp_invariant": ("MLP invariant", "#bcbd22"), "mlp_augmented": ("MLP augmented", "#d62728"), "schnet": ("SchNet", "#ff7f0e"), "e3nn": ("e3nn (SE3)", "#9467bd"), "mace": ("MACE", "#8c564b"), } def main(): ap = argparse.ArgumentParser() ap.add_argument("--target", default="gap") ap.add_argument("--out_dir", default="results/") ap.add_argument("--summary_csv", default="results/summary.csv") ap.add_argument("--isomer_csv", default="results/isomer_audit.csv") args = ap.parse_args() out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) summary = pd.read_csv(args.summary_csv) summary = summary[summary["target"] == args.target] have_isomer = Path(args.isomer_csv).exists() if have_isomer: isomer = pd.read_csv(args.isomer_csv) isomer = isomer[isomer.get("seed", 0).astype(int) == 0] if "seed" in isomer else isomer # aggregate across seeds: weighted mean of within_isomer_r2 if "seed" in isomer.columns: isomer_agg = ( isomer.groupby("method")["within_isomer_r2_weighted_mean"] .mean().reset_index() .rename(columns={"within_isomer_r2_weighted_mean": "isomer_r2"}) ) else: isomer_agg = isomer.rename(columns={"within_isomer_r2_weighted_mean": "isomer_r2"}) fig, axes = plt.subplots(1, 2 if have_isomer else 1, figsize=(12 if have_isomer else 7, 5), sharey=False) if not have_isomer: axes = [axes] # Panel (a): pooled R² vs n_params axL = axes[0] for method, (label, color) in METHOD_STYLE.items(): row = summary[summary["method"] == method] if row.empty: continue x = float(row["n_params"].iloc[0]) y = float(row["r2_mean"].iloc[0]) yerr = float(row["r2_std"].iloc[0]) if not pd.isna(row["r2_std"].iloc[0]) else 0 axL.errorbar(x, y, yerr=yerr, fmt="o", markersize=10, color=color, label=label, capsize=4, elinewidth=1.2) axL.annotate(label, (x, y), xytext=(8, 4), textcoords="offset points", fontsize=9, color=color) axL.set_xscale("log") axL.set_xlabel("Trainable parameters (log)", fontsize=11) axL.set_ylabel(f"Pooled test R² ({args.target})", fontsize=11) axL.set_title("(a) Pooled R²: advertised QM9 score", fontsize=12) axL.grid(True, which="both", alpha=0.25, linestyle=":") axL.set_ylim(-0.05, 1.02) # Panel (b): within-isomer R² vs n_params if have_isomer: axR = axes[1] # Build a combined frame merged = summary.merge(isomer_agg, on="method", how="left") for method, (label, color) in METHOD_STYLE.items(): row = merged[merged["method"] == method] if row.empty or pd.isna(row["isomer_r2"].iloc[0]): continue x = float(row["n_params"].iloc[0]) y = float(row["isomer_r2"].iloc[0]) axR.scatter(x, y, s=120, color=color, label=label, edgecolor="black", linewidth=0.6, zorder=3) axR.annotate(label, (x, y), xytext=(8, 4), textcoords="offset points", fontsize=9, color=color) axR.set_xscale("log") axR.set_xlabel("Trainable parameters (log)", fontsize=11) axR.set_ylabel(f"Within-isomer test R² ({args.target})", fontsize=11) axR.set_title("(b) Within-isomer R²: size-prediction signal removed", fontsize=12) axR.grid(True, which="both", alpha=0.25, linestyle=":") fig.suptitle( f"Parameter-efficiency vs predictive power on QM9 {args.target}", fontsize=13, y=1.02, ) fig.tight_layout() pdf_path = out_dir / f"pareto_{args.target}.pdf" png_path = out_dir / f"pareto_{args.target}.png" fig.savefig(pdf_path, bbox_inches="tight") fig.savefig(png_path, bbox_inches="tight", dpi=160) plt.close(fig) print(f"[ok] wrote {pdf_path} and {png_path}") if __name__ == "__main__": main()