tensor-group-sym / python / large_scale / train_baseline_schnet.py
train_baseline_schnet.py
Raw
"""SchNet baseline (invariant ENN).

SchNet (Schütt et al., 2017) is the standard invariant-ENN baseline on QM9.
It builds a continuous-filter convolutional network on atomic graphs, uses
distance-based features only (no angular/equivariant features), and predicts
scalar properties. We use the schnetpack reference implementation pinned to
v2.1.1 in requirements.txt.

Usage:
    python train_baseline_schnet.py --target gap --qm9_dir /path/to/qm9 \
        --seed 0 --out_dir results/

This script writes results/schnet/<target>/seed<k>.json with the same schema
as the ★_G and MLP scripts.
"""

from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import numpy as np
import torch

try:
    import schnetpack as spk
    import schnetpack.transform as trn
    from torch.utils.data import DataLoader
    SCHNET_AVAILABLE = True
except ImportError:
    SCHNET_AVAILABLE = False


# Short QM9 property names accepted by --target, mapped to the property key
# that is handed to schnetpack (dataset transforms, output module, metrics).
SCHNET_TARGET_KEYS = dict(
    mu="dipole_moment",
    alpha="isotropic_polarizability",
    homo="homo",
    lumo="lumo",
    gap="gap",
    r2="electronic_spatial_extent",
    zpve="zpve",
    U0="energy_U0",
    U="energy_U",
    H="enthalpy_H",
    G="free_energy",
    Cv="heat_capacity",
)


# Targets whose unit is pinned explicitly to Hartree ("Ha"). Every other
# target gets no entry in property_units and is therefore loaded in the unit
# stored in the database (schnetpack expects unit strings there, not None).
_HARTREE_TARGETS = frozenset({"homo", "lumo", "gap", "zpve", "U0", "U", "H", "G"})


def _regression_metrics(preds, labels):
    """Compute R², RMSE and MAE of ``preds`` against ``labels``.

    Args:
        preds: 1-D array of predictions.
        labels: 1-D array of reference values, same length as ``preds``.

    Returns:
        dict with float entries ``"r2"``, ``"rmse"`` and ``"mae"``. ``"r2"`` is
        NaN when ``labels`` is constant (total sum of squares is zero).
    """
    residual = labels - preds
    ss_res = float(np.sum(residual**2))
    ss_tot = float(np.sum((labels - labels.mean()) ** 2))
    return {
        "r2": 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan"),
        "rmse": float(np.sqrt(np.mean(residual**2))),
        "mae": float(np.mean(np.abs(residual))),
    }


def _predict_test_split(task, datamodule, target_key, device):
    """Predict the test split with the current weights of ``task``.

    Args:
        task: the trained schnetpack ``AtomisticTask``.
        datamodule: the QM9 datamodule whose ``test_dataloader()`` is used.
        target_key: schnetpack property key of the target.
        device: torch device to evaluate on.

    Returns:
        ``(preds, labels)`` as 1-D float64 numpy arrays, both in the target's
        original units (the offset removed by the ``RemoveOffsets`` data
        transform is added back to the labels).

    Raises:
        RuntimeError: if the test dataloader yields no batches.
    """
    # Lightning's teardown after trainer.test() can leave the module on the
    # CPU, while the batches below are moved to `device`.
    task.to(device)
    task.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in datamodule.test_dataloader():
            batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
            # Read the label BEFORE any forward pass: the output module uses
            # output_key == target_key, so the model writes its prediction into
            # the dict under the label's key. Each pass below gets its own
            # shallow copy so `batch` itself stays untouched.
            label_shifted = batch[target_key].detach().clone().double().reshape(-1)
            # The label went through RemoveOffsets. The raw model output lives in
            # that same shifted space; the postprocessed output (AddOffsets) is
            # in original units. Their difference is the per-molecule offset.
            pred_shifted = task.predict_without_postprocessing(dict(batch))[target_key]
            pred = task(dict(batch))[target_key]
            pred_shifted = pred_shifted.double().reshape(-1)
            pred = pred.double().reshape(-1)
            preds.append(pred.cpu().numpy())
            labels.append((label_shifted + (pred - pred_shifted)).cpu().numpy())
    if not preds:
        raise RuntimeError("test split is empty; cannot compute metrics")
    return np.concatenate(preds), np.concatenate(labels)


def main():
    """Train SchNet on one QM9 target and write ``<out_dir>/schnet/<target>/seed<k>.json``.

    Alongside the JSON result, the best checkpoint is kept in
    ``ckpt_seed<k>/`` and the train/val/test split in ``split_seed<k>.npz``,
    both inside ``<out_dir>/schnet/<target>/``.

    Raises:
        RuntimeError: if schnetpack is not installed, or the test split is empty.
    """
    if not SCHNET_AVAILABLE:
        raise RuntimeError(
            "schnetpack is not installed. Run: pip install schnetpack==2.1.1"
        )

    ap = argparse.ArgumentParser()
    ap.add_argument("--target", required=True, choices=list(SCHNET_TARGET_KEYS),
                    help="QM9 property name")
    ap.add_argument("--qm9_dir", required=True,
                    help="path to QM9 raw directory (or schnetpack cache root)")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--cutoff", type=float, default=5.0)
    ap.add_argument("--n_features", type=int, default=128)
    ap.add_argument("--n_interactions", type=int, default=6)
    ap.add_argument("--n_rbf", type=int, default=20)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--batch", type=int, default=64)
    ap.add_argument("--out_dir", default="results/")
    ap.add_argument("--n_train", type=int, default=110000)
    ap.add_argument("--n_val", type=int, default=10000)
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    out_dir = Path(args.out_dir) / "schnet" / args.target
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"seed{args.seed}.json"

    spk_target = SCHNET_TARGET_KEYS[args.target]
    property_units = {spk_target: "Ha"} if args.target in _HARTREE_TARGETS else None

    # schnetpack's QM9 wrapper draws a random n_train / n_val split; the
    # remaining molecules (10,831 with the defaults) form the test set.
    # remove_uncharacterized=True matches the 130,831-molecule subset that
    # PyG's QM9 v3 also uses (3054 uncharacterized molecules from
    # Ramakrishnan 2014 are excluded for consistency).
    # The split is persisted per seed inside out_dir so reruns are
    # reproducible; schnetpack's default split_file is a relative
    # "split.npz" that every run from the same working directory would reuse.
    qm9 = spk.datasets.QM9(
        datapath=str(Path(args.qm9_dir) / f"qm9_seed{args.seed}.db"),
        batch_size=args.batch,
        num_train=args.n_train,
        num_val=args.n_val,
        split_file=str(out_dir / f"split_seed{args.seed}.npz"),
        remove_uncharacterized=True,
        transforms=[
            trn.ASENeighborList(cutoff=args.cutoff),
            trn.RemoveOffsets(spk_target, remove_mean=True, remove_atomrefs=False),
            trn.CastTo32(),
        ],
        property_units=property_units,
        num_workers=4,
        pin_memory=True,
    )
    qm9.prepare_data()
    qm9.setup()

    # Build SchNet model.
    # NOTE(review): Atomwise and RemoveOffsets/AddOffsets run with schnetpack's
    # default aggregation/extensivity for every target; confirm that is the
    # intended baseline configuration for intensive targets (gap, homo, lumo).
    pairwise_distance = spk.atomistic.PairwiseDistances()
    radial_basis = spk.nn.GaussianRBF(n_rbf=args.n_rbf, cutoff=args.cutoff)
    schnet = spk.representation.SchNet(
        n_atom_basis=args.n_features,
        n_interactions=args.n_interactions,
        radial_basis=radial_basis,
        cutoff_fn=spk.nn.CosineCutoff(args.cutoff),
    )
    pred_head = spk.atomistic.Atomwise(n_in=args.n_features, output_key=spk_target)
    model = spk.model.NeuralNetworkPotential(
        representation=schnet,
        input_modules=[pairwise_distance],
        output_modules=[pred_head],
        postprocessors=[trn.CastTo64(), trn.AddOffsets(spk_target, add_mean=True)],
    )

    # Note: schnetpack wraps each metric in nn.ModuleDict, so every entry
    # must be an nn.Module (lambdas are rejected). Use only L1Loss here (for
    # early stopping / checkpointing); the reported R², RMSE and MAE are all
    # recomputed from the saved test predictions below.
    output = spk.task.ModelOutput(
        name=spk_target,
        loss_fn=torch.nn.MSELoss(),
        loss_weight=1.0,
        metrics={"MAE": torch.nn.L1Loss()},
    )
    task = spk.task.AtomisticTask(
        model=model,
        outputs=[output],
        optimizer_cls=torch.optim.Adam,
        optimizer_args={"lr": args.lr},
    )

    import pytorch_lightning as pl

    if device.type == "cuda":
        # Honour an explicit index such as --device cuda:1.
        accelerator, devices = "gpu", [device.index or 0]
    else:
        accelerator, devices = "cpu", 1
    monitor = f"val_{spk_target}_MAE"
    trainer = pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=args.epochs,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor=monitor, patience=20),
            pl.callbacks.ModelCheckpoint(
                monitor=monitor, save_top_k=1, mode="min",
                dirpath=str(out_dir / f"ckpt_seed{args.seed}"),
            ),
        ],
        enable_progress_bar=True,
        log_every_n_steps=50,
    )

    t0 = time.perf_counter()
    trainer.fit(task, datamodule=qm9)
    train_time = time.perf_counter() - t0

    # ckpt_path="best" loads the best checkpoint's weights into `task` before
    # testing; the prediction pass below then uses those weights.
    trainer.test(task, datamodule=qm9, ckpt_path="best")

    # All three metrics come from the same saved predictions, so they are
    # mutually consistent (no per-batch averaging, no silent 0.0 fallback).
    preds, labels = _predict_test_split(task, qm9, spk_target, device)
    metrics = _regression_metrics(preds, labels)
    r2, rmse, mae = metrics["r2"], metrics["rmse"], metrics["mae"]

    result = {
        "method": "schnet",
        "target": args.target,
        "seed": args.seed,
        "n_train": args.n_train,
        "n_val": args.n_val,
        "n_features": args.n_features,
        "n_interactions": args.n_interactions,
        "n_params": sum(p.numel() for p in model.parameters()),
        "train_time_s": train_time,
        "r2": r2, "rmse": rmse, "mae": mae,
        "predictions": preds.tolist(),
        "labels": labels.tolist(),
    }
    with open(out_file, "w") as fp:
        json.dump(result, fp, indent=2)
    print(f"[done] schnet R2={r2:.4f}  MAE={mae:.4g}  -> {out_file}")


# Script entry point: train and evaluate only when executed directly, so the
# module can be imported (e.g. for SCHNET_TARGET_KEYS) without side effects.
if __name__ == "__main__":
    main()