"""SchNet baseline (invariant ENN). SchNet (Schütt et al., 2017) is the standard invariant-ENN baseline on QM9. It builds a continuous-filter convolutional network on atomic graphs, uses distance-based features only (no angular/equivariant features), and predicts scalar properties. We use the schnetpack reference implementation pinned to v2.1.1 in requirements.txt. Usage: python train_baseline_schnet.py --target gap --qm9_dir /path/to/qm9 \ --seed 0 --out_dir results/ This script writes results/schnet//seed.json with the same schema as the ★_G and MLP scripts. """ from __future__ import annotations import argparse import json import time from pathlib import Path import numpy as np import torch try: import schnetpack as spk import schnetpack.transform as trn from torch.utils.data import DataLoader SCHNET_AVAILABLE = True except ImportError: SCHNET_AVAILABLE = False SCHNET_TARGET_KEYS = { "mu": "dipole_moment", "alpha": "isotropic_polarizability", "homo": "homo", "lumo": "lumo", "gap": "gap", "r2": "electronic_spatial_extent", "zpve": "zpve", "U0": "energy_U0", "U": "energy_U", "H": "enthalpy_H", "G": "free_energy", "Cv": "heat_capacity", } def main(): if not SCHNET_AVAILABLE: raise RuntimeError( "schnetpack is not installed. Run: pip install schnetpack==2.1.1" ) ap = argparse.ArgumentParser() ap.add_argument("--target", required=True, help="QM9 property name") ap.add_argument("--qm9_dir", required=True, help="path to QM9 raw directory (or schnetpack cache root)") ap.add_argument("--seed", type=int, default=0) ap.add_argument("--cutoff", type=float, default=5.0) ap.add_argument("--n_features", type=int, default=128) ap.add_argument("--n_interactions", type=int, default=6) ap.add_argument("--n_rbf", type=int, default=20) ap.add_argument("--epochs", type=int, default=200) ap.add_argument("--lr", type=float, default=5e-4) ap.add_argument("--batch", type=int, default=64) ap.add_argument("--out_dir", default="results/") ap.add_argument("--n_train", type=int, default=110000) ap.add_argument("--n_val", type=int, default=10000) ap.add_argument("--device", default="cuda") args = ap.parse_args() torch.manual_seed(args.seed) np.random.seed(args.seed) device = torch.device(args.device if torch.cuda.is_available() else "cpu") out_dir = Path(args.out_dir) / "schnet" / args.target out_dir.mkdir(parents=True, exist_ok=True) out_file = out_dir / f"seed{args.seed}.json" spk_target = SCHNET_TARGET_KEYS[args.target] # schnetpack QM9 dataset wrapper handles the standard 110k/10k/13k split. # remove_uncharacterized=True matches the 130,831-molecule subset that # PyG's QM9 v3 also uses (3054 uncharacterized molecules from # Ramakrishnan 2014 are excluded for consistency). qm9 = spk.datasets.QM9( datapath=str(Path(args.qm9_dir) / f"qm9_seed{args.seed}.db"), batch_size=args.batch, num_train=args.n_train, num_val=args.n_val, remove_uncharacterized=True, transforms=[ trn.ASENeighborList(cutoff=args.cutoff), trn.RemoveOffsets(spk_target, remove_mean=True, remove_atomrefs=False), trn.CastTo32(), ], property_units={spk_target: "Ha" if args.target in ("homo", "lumo", "gap", "zpve", "U0", "U", "H", "G") else None}, num_workers=4, pin_memory=True, ) qm9.prepare_data() qm9.setup() # Build SchNet model pairwise_distance = spk.atomistic.PairwiseDistances() radial_basis = spk.nn.GaussianRBF(n_rbf=args.n_rbf, cutoff=args.cutoff) schnet = spk.representation.SchNet( n_atom_basis=args.n_features, n_interactions=args.n_interactions, radial_basis=radial_basis, cutoff_fn=spk.nn.CosineCutoff(args.cutoff), ) pred_head = spk.atomistic.Atomwise(n_in=args.n_features, output_key=spk_target) model = spk.model.NeuralNetworkPotential( representation=schnet, input_modules=[pairwise_distance], output_modules=[pred_head], postprocessors=[trn.CastTo64(), trn.AddOffsets(spk_target, add_mean=True)], ) # Note: schnetpack wraps each metric in nn.ModuleDict, so every entry # must be an nn.Module (lambdas are rejected). Use only L1Loss here; # we recompute RMSE/R² ourselves from saved predictions below. output = spk.task.ModelOutput( name=spk_target, loss_fn=torch.nn.MSELoss(), loss_weight=1.0, metrics={"MAE": torch.nn.L1Loss()}, ) task = spk.task.AtomisticTask( model=model, outputs=[output], optimizer_cls=torch.optim.Adam, optimizer_args={"lr": args.lr}, ) import pytorch_lightning as pl trainer = pl.Trainer( accelerator="gpu" if device.type == "cuda" else "cpu", devices=1, max_epochs=args.epochs, callbacks=[ pl.callbacks.EarlyStopping(monitor=f"val_{spk_target}_MAE", patience=20), pl.callbacks.ModelCheckpoint( monitor=f"val_{spk_target}_MAE", save_top_k=1, mode="min", dirpath=str(out_dir / f"ckpt_seed{args.seed}"), ), ], enable_progress_bar=True, log_every_n_steps=50, ) t0 = time.time() trainer.fit(task, datamodule=qm9) train_time = time.time() - t0 # Evaluate on test split test_results = trainer.test(task, datamodule=qm9, ckpt_path="best")[0] mae = float(test_results.get(f"test_{spk_target}_MAE", 0.0)) # R2 requires running predictions; do a final pass task.eval() preds, labels = [], [] for batch in qm9.test_dataloader(): batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()} with torch.no_grad(): out = task(batch) preds.append(out[spk_target].cpu().numpy()) labels.append(batch[spk_target].cpu().numpy()) preds = np.concatenate(preds); labels = np.concatenate(labels) ss_res = ((labels - preds) ** 2).sum(); ss_tot = ((labels - labels.mean()) ** 2).sum() r2 = float(1 - ss_res / ss_tot) rmse = float(np.sqrt(((labels - preds) ** 2).mean())) result = { "method": "schnet", "target": args.target, "seed": args.seed, "n_train": args.n_train, "n_val": args.n_val, "n_features": args.n_features, "n_interactions": args.n_interactions, "n_params": sum(p.numel() for p in model.parameters()), "train_time_s": train_time, "r2": r2, "rmse": rmse, "mae": mae, "predictions": preds.tolist(), "labels": labels.tolist(), } with open(out_file, "w") as fp: json.dump(result, fp, indent=2) print(f"[done] schnet R2={r2:.4f} MAE={mae:.4g} -> {out_file}") if __name__ == "__main__": main()