# tensor-group-sym / python / large_scale / pareto_figure.py
"""Pareto-frontier figure: R² vs parameter count on QM9 HOMO-LUMO gap.

Plots R² (linear y-axis) vs n_params (log x-axis) for every method, with two panels:
  (a) pooled R² (the "advertised" QM9 scoreboard)
  (b) within-isomer R² (the controlled-for-size comparison)

The visual story: ★_G ridge sits at the leftmost point of the parameter
axis, MACE at the rightmost, and the within-isomer panel shows that the
apparent ENN advantage in pooled R² narrows substantially when
size-prediction signal is removed.

Inputs:
  - results/summary.csv produced by eval_collect.py (pooled R²)
  - results/isomer_audit.csv produced by isomer_audit.py (within-isomer R²)

Output (written into --out_dir, which is created if missing):
  - <out_dir>/pareto_<target>.pdf
  - <out_dir>/pareto_<target>.png

Usage:
    python pareto_figure.py --target gap --out_dir results/
"""

from __future__ import annotations

import argparse
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np


# Plot styling per method, one row per method:
#   (method key as it appears in the CSVs, display label, matplotlib colour).
# Row order is preserved in METHOD_STYLE and is the order in which methods
# are drawn. The label is also attached to each artist, so any legend built
# from the axes follows this order too.
_METHOD_STYLE_ROWS = (
    ("starg_ridge", "★_G ridge", "#1f77b4"),
    ("starg_neural", "★_G neural", "#2ca02c"),
    ("mlp_standard", "MLP standard", "#7f7f7f"),
    ("mlp_invariant", "MLP invariant", "#bcbd22"),
    ("mlp_augmented", "MLP augmented", "#d62728"),
    ("schnet", "SchNet", "#ff7f0e"),
    ("e3nn", "e3nn (SE3)", "#9467bd"),
    ("mace", "MACE", "#8c564b"),
)

# method key -> (display label, colour); dicts keep insertion order.
METHOD_STYLE = {key: (label, colour) for key, label, colour in _METHOD_STYLE_ROWS}


def _load_isomer_r2(csv_path: Path, target: str) -> pd.DataFrame | None:
    """Load the within-isomer audit and reduce it to one row per method.

    Args:
        csv_path: Path to the CSV written by isomer_audit.py.
        target: Target being plotted. When the CSV has a ``target`` column,
            rows for other targets are dropped.

    Returns:
        ``None`` if *csv_path* does not exist. Otherwise a frame with exactly
        two columns, ``method`` and ``isomer_r2``. ``isomer_r2`` is the mean
        of ``within_isomer_r2_weighted_mean`` over all rows recorded for that
        method (e.g. all seeds).

    Raises:
        KeyError: if the CSV lacks a ``method`` or
            ``within_isomer_r2_weighted_mean`` column.
    """
    if not csv_path.exists():
        return None
    isomer = pd.read_csv(csv_path)
    if "target" in isomer.columns:
        isomer = isomer[isomer["target"] == target]
    # Average over every seed rather than keeping only seed 0, and keep just
    # the two columns the later merge needs. Merging the raw audit frame
    # would duplicate rows per method and suffix any column it shares with
    # the summary (e.g. n_params -> n_params_x), breaking the n_params
    # lookup in _plot_isomer.
    return (
        isomer.groupby("method", as_index=False)["within_isomer_r2_weighted_mean"]
        .mean()
        .rename(columns={"within_isomer_r2_weighted_mean": "isomer_r2"})
    )


def _plot_pooled(ax, summary: pd.DataFrame, target: str) -> None:
    """Draw panel (a): pooled test R² (``r2_mean`` ± ``r2_std``) vs ``n_params``.

    Methods are drawn in METHOD_STYLE order. Methods absent from *summary* are
    skipped, and only the first row per method is used. A missing or NaN
    ``r2_std`` is drawn as a zero-length error bar.
    """
    lowest = -0.05
    for method, (label, color) in METHOD_STYLE.items():
        row = summary[summary["method"] == method]
        if row.empty:
            continue
        x = float(row["n_params"].iloc[0])
        y = float(row["r2_mean"].iloc[0])
        std = row["r2_std"].iloc[0] if "r2_std" in row.columns else None
        yerr = 0.0 if pd.isna(std) else float(std)
        ax.errorbar(x, y, yerr=yerr, fmt="o", markersize=10,
                    color=color, label=label, capsize=4, elinewidth=1.2)
        ax.annotate(label, (x, y),
                    xytext=(8, 4), textcoords="offset points",
                    fontsize=9, color=color)
        lowest = min(lowest, y - yerr - 0.05)

    ax.set_xscale("log")
    ax.set_xlabel("Trainable parameters (log)", fontsize=11)
    ax.set_ylabel(f"Pooled test R² ({target})", fontsize=11)
    ax.set_title("(a) Pooled R²: advertised QM9 score", fontsize=12)
    ax.grid(True, which="both", alpha=0.25, linestyle=":")
    # Keep the usual [-0.05, 1.02] window, but lower the floor when a
    # method's R² minus its error bar is negative, so that method is not
    # clipped off the panel.
    ax.set_ylim(lowest, 1.02)


def _plot_isomer(ax, summary: pd.DataFrame, isomer_agg: pd.DataFrame,
                 target: str) -> None:
    """Draw panel (b): within-isomer R² vs ``n_params``.

    ``n_params`` comes from *summary* and ``isomer_r2`` from *isomer_agg*
    (as returned by _load_isomer_r2). Methods missing from *summary*, or with
    no within-isomer value, are skipped.
    """
    merged = summary.merge(isomer_agg, on="method", how="left")
    for method, (label, color) in METHOD_STYLE.items():
        row = merged[merged["method"] == method]
        if row.empty or pd.isna(row["isomer_r2"].iloc[0]):
            continue
        x = float(row["n_params"].iloc[0])
        y = float(row["isomer_r2"].iloc[0])
        ax.scatter(x, y, s=120, color=color, label=label,
                   edgecolor="black", linewidth=0.6, zorder=3)
        ax.annotate(label, (x, y),
                    xytext=(8, 4), textcoords="offset points",
                    fontsize=9, color=color)

    ax.set_xscale("log")
    ax.set_xlabel("Trainable parameters (log)", fontsize=11)
    ax.set_ylabel(f"Within-isomer test R² ({target})", fontsize=11)
    ax.set_title("(b) Within-isomer R²: size-prediction signal removed",
                 fontsize=12)
    ax.grid(True, which="both", alpha=0.25, linestyle=":")


def main():
    """Build the Pareto figure for one target and save it as PDF and PNG.

    Reads --summary_csv (required) and --isomer_csv (optional). Panel (b) is
    omitted when the isomer file is missing or has no rows for --target.
    Writes ``pareto_<target>.pdf`` and ``pareto_<target>.png`` into
    --out_dir, creating the directory if needed.

    Raises:
        SystemExit: if --summary_csv has no rows for --target.
    """
    ap = argparse.ArgumentParser(
        description="Pareto figure: R² vs trainable-parameter count.")
    ap.add_argument("--target", default="gap",
                    help="value of the 'target' column to plot")
    ap.add_argument("--out_dir", default="results/",
                    help="directory for pareto_<target>.pdf/.png")
    ap.add_argument("--summary_csv", default="results/summary.csv",
                    help="pooled-R² summary from eval_collect.py")
    ap.add_argument("--isomer_csv", default="results/isomer_audit.csv",
                    help="within-isomer audit from isomer_audit.py (optional)")
    args = ap.parse_args()

    summary = pd.read_csv(args.summary_csv)
    summary = summary[summary["target"] == args.target]
    if summary.empty:
        # Most likely a typo in --target; an empty figure would be misleading.
        raise SystemExit(
            f"[error] no rows with target == {args.target!r} in {args.summary_csv}"
        )

    isomer_agg = _load_isomer_r2(Path(args.isomer_csv), args.target)
    if isomer_agg is not None and isomer_agg.empty:
        print(f"[warn] {args.isomer_csv} has no rows for target "
              f"{args.target!r}; omitting the within-isomer panel")
        isomer_agg = None
    have_isomer = isomer_agg is not None

    # squeeze=False always returns a 2-D array of axes, so the one- and
    # two-panel layouts share a single indexing path.
    fig, axes = plt.subplots(
        1, 2 if have_isomer else 1,
        figsize=(12 if have_isomer else 7, 5),
        sharey=False, squeeze=False,
    )
    axes = axes[0]

    _plot_pooled(axes[0], summary, args.target)
    if have_isomer:
        _plot_isomer(axes[1], summary, isomer_agg, args.target)

    fig.suptitle(
        f"Parameter-efficiency vs predictive power on QM9 {args.target}",
        fontsize=13, y=1.02,
    )
    fig.tight_layout()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    pdf_path = out_dir / f"pareto_{args.target}.pdf"
    png_path = out_dir / f"pareto_{args.target}.png"
    # bbox_inches="tight" keeps the suptitle (placed above the axes at
    # y=1.02) inside the saved image.
    fig.savefig(pdf_path, bbox_inches="tight")
    fig.savefig(png_path, bbox_inches="tight", dpi=160)
    plt.close(fig)
    print(f"[ok] wrote {pdf_path} and {png_path}")


if __name__ == "__main__":
    main()