"""SchNet baseline (invariant ENN).
SchNet (Schütt et al., 2017) is the standard invariant-ENN baseline on QM9.
It builds a continuous-filter convolutional network on atomic graphs, uses
distance-based features only (no angular/equivariant features), and predicts
scalar properties. We use the schnetpack reference implementation pinned to
v2.1.1 in requirements.txt.
Usage:
python train_baseline_schnet.py --target gap --qm9_dir /path/to/qm9 \
--seed 0 --out_dir results/
This script writes results/schnet/<target>/seed<k>.json with the same schema
as the ★_G and MLP scripts.
"""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
import numpy as np
import torch
# schnetpack (and its transform submodule) is an optional dependency: record
# whether it imported so main() can fail with an actionable message instead
# of an ImportError at module import time.
try:
    import schnetpack as spk
    import schnetpack.transform as trn
    from torch.utils.data import DataLoader
except ImportError:
    SCHNET_AVAILABLE = False
else:
    SCHNET_AVAILABLE = True
# Maps the short QM9 target names accepted by --target to the property keys
# this script uses with schnetpack (dataset column, transform and output key).
SCHNET_TARGET_KEYS = dict(
    mu="dipole_moment",
    alpha="isotropic_polarizability",
    homo="homo",
    lumo="lumo",
    gap="gap",
    r2="electronic_spatial_extent",
    zpve="zpve",
    U0="energy_U0",
    U="energy_U",
    H="enthalpy_H",
    G="free_energy",
    Cv="heat_capacity",
)
# QM9 targets that are requested in Hartree ("Ha") via ``property_units``;
# every other target is left in the dataset's native unit (no entry at all,
# rather than an entry mapped to ``None``).
_HARTREE_TARGETS = frozenset({"homo", "lumo", "gap", "zpve", "U0", "U", "H", "G"})


def _regression_metrics(preds, labels):
    """Return ``{"r2", "rmse", "mae"}`` of *preds* against *labels*.

    Both inputs are flattened to 1-D float64 arrays before comparison.

    Raises:
        ValueError: if the two arrays differ in size or are empty.
    """
    preds = np.asarray(preds, dtype=np.float64).reshape(-1)
    labels = np.asarray(labels, dtype=np.float64).reshape(-1)
    if preds.shape != labels.shape:
        raise ValueError(
            f"predictions {preds.shape} and labels {labels.shape} differ in size"
        )
    if labels.size == 0:
        raise ValueError("cannot compute metrics on an empty test split")
    residuals = labels - preds
    ss_res = float(np.sum(residuals ** 2))
    ss_tot = float(np.sum((labels - labels.mean()) ** 2))
    # R² is undefined for a constant target; report NaN instead of dividing by 0.
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0.0 else float("nan")
    return {
        "r2": r2,
        "rmse": float(np.sqrt(np.mean(residuals ** 2))),
        "mae": float(np.mean(np.abs(residuals))),
    }


def _collect_test_predictions(task, dataloader, postprocessors, key, device):
    """Run *task* over *dataloader* and return ``(preds, labels)`` as 1-D arrays.

    Two things keep predictions and labels comparable:

    * The model's output module stores its prediction in the batch dict under
      *key* (the same key as the target), so the label is cloned out of the
      batch *before* the forward pass.
    * Labels arrive with the datamodule's ``RemoveOffsets`` transform applied,
      while the model output has passed through *postprocessors*
      (``CastTo64`` + ``AddOffsets``). The same postprocessor modules are
      applied to a copy of the batch holding the label, so both arrays end up
      on the same scale.
    """
    # Lightning moves the module back to CPU after fit/test, so place it (and
    # the postprocessors' buffers, which are submodules) on the eval device.
    task.to(device)
    task.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            batch = {
                k: v.to(device) if torch.is_tensor(v) else v
                for k, v in batch.items()
            }
            # Shallow copy: the postprocessors rebind/modify entries of this
            # dict only; the label tensor itself is cloned.
            label_inputs = dict(batch)
            label_inputs[key] = batch[key].clone()
            for pp in postprocessors:
                label_inputs = pp(label_inputs)
            out = task(batch)
            preds.append(out[key].detach().cpu().numpy().reshape(-1))
            labels.append(label_inputs[key].detach().cpu().numpy().reshape(-1))
    return np.concatenate(preds), np.concatenate(labels)


def main():
    """Train SchNet on one QM9 target and write test-set results as JSON.

    Parses the command line, trains with PyTorch Lightning (early stopping
    and checkpointing on validation MAE), tests the best checkpoint, and
    writes ``<out_dir>/schnet/<target>/seed<k>.json`` containing R², RMSE,
    MAE, the run configuration, and the test predictions/labels.

    Raises:
        RuntimeError: if schnetpack is not installed.
    """
    if not SCHNET_AVAILABLE:
        raise RuntimeError(
            "schnetpack is not installed. Run: pip install schnetpack==2.1.1"
        )
    import pytorch_lightning as pl

    ap = argparse.ArgumentParser()
    ap.add_argument("--target", required=True, choices=list(SCHNET_TARGET_KEYS),
                    help="QM9 property name")
    ap.add_argument("--qm9_dir", required=True,
                    help="path to QM9 raw directory (or schnetpack cache root)")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--cutoff", type=float, default=5.0)
    ap.add_argument("--n_features", type=int, default=128)
    ap.add_argument("--n_interactions", type=int, default=6)
    ap.add_argument("--n_rbf", type=int, default=20)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--batch", type=int, default=64)
    ap.add_argument("--out_dir", default="results/")
    ap.add_argument("--n_train", type=int, default=110000)
    ap.add_argument("--n_val", type=int, default=10000)
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    out_dir = Path(args.out_dir) / "schnet" / args.target
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"seed{args.seed}.json"
    spk_target = SCHNET_TARGET_KEYS[args.target]

    # The datamodule splits into train/val sets of the requested sizes; with
    # num_test unset, the remaining molecules form the test split
    # (130,831 - 110,000 - 10,000 = 10,831 with the defaults).
    # remove_uncharacterized=True matches the 130,831-molecule subset that
    # PyG's QM9 v3 also uses (the 3,054 uncharacterized molecules from
    # Ramakrishnan et al. 2014 are excluded for consistency).
    # The split indices go to an explicit per-run file in out_dir rather than
    # schnetpack's default cwd-relative "split.npz", which every seed/target
    # run launched from the same directory would otherwise share.
    # NOTE(review): one .db per seed means QM9 is prepared once per seed; a
    # shared database would avoid the repeated preparation.
    property_units = {spk_target: "Ha"} if args.target in _HARTREE_TARGETS else None
    qm9 = spk.datasets.QM9(
        datapath=str(Path(args.qm9_dir) / f"qm9_seed{args.seed}.db"),
        batch_size=args.batch,
        num_train=args.n_train,
        num_val=args.n_val,
        split_file=str(
            out_dir / f"split_seed{args.seed}_{args.n_train}_{args.n_val}.npz"
        ),
        remove_uncharacterized=True,
        transforms=[
            trn.ASENeighborList(cutoff=args.cutoff),
            trn.RemoveOffsets(spk_target, remove_mean=True, remove_atomrefs=False),
            trn.CastTo32(),
        ],
        property_units=property_units,
        num_workers=4,
        pin_memory=True,
    )
    qm9.prepare_data()
    qm9.setup()

    # Build the SchNet model.
    radial_basis = spk.nn.GaussianRBF(n_rbf=args.n_rbf, cutoff=args.cutoff)
    schnet = spk.representation.SchNet(
        n_atom_basis=args.n_features,
        n_interactions=args.n_interactions,
        radial_basis=radial_basis,
        cutoff_fn=spk.nn.CosineCutoff(args.cutoff),
    )
    # NOTE(review): Atomwise appears to sum per-atom contributions by default
    # (and RemoveOffsets/AddOffsets to treat the mean as per-atom), which suits
    # extensive targets like U0; schnetpack suggests aggregation_mode="avg" for
    # intensive targets such as homo/lumo/gap. Confirm the intended setting.
    pred_head = spk.atomistic.Atomwise(n_in=args.n_features, output_key=spk_target)
    # Kept as a named list so the same (trained, checkpoint-restored) modules
    # can be applied to the test labels in _collect_test_predictions.
    postprocessors = [trn.CastTo64(), trn.AddOffsets(spk_target, add_mean=True)]
    model = spk.model.NeuralNetworkPotential(
        representation=schnet,
        input_modules=[spk.atomistic.PairwiseDistances()],
        output_modules=[pred_head],
        postprocessors=postprocessors,
    )

    # schnetpack wraps each metric in nn.ModuleDict, so every entry must be an
    # nn.Module (lambdas are rejected). The MAE/RMSE/R² written to the results
    # file are recomputed below from the final test predictions.
    # NOTE(review): schnetpack's QM9 examples use
    # torchmetrics.MeanAbsoluteError() here, and ModelOutput appears to
    # .clone() each metric and log it as a metric object; confirm plain
    # nn.L1Loss works under schnetpack 2.1.1, else switch to torchmetrics.
    output = spk.task.ModelOutput(
        name=spk_target,
        loss_fn=torch.nn.MSELoss(),
        loss_weight=1.0,
        metrics={"MAE": torch.nn.L1Loss()},
    )
    task = spk.task.AtomisticTask(
        model=model,
        outputs=[output],
        optimizer_cls=torch.optim.Adam,
        optimizer_args={"lr": args.lr},
    )

    if device.type == "cuda":
        accelerator = "gpu"
        # Honour an explicit index such as --device cuda:1; devices=1 would
        # always select the first visible GPU.
        devices = [device.index] if device.index is not None else 1
    else:
        accelerator, devices = "cpu", 1
    monitor = f"val_{spk_target}_MAE"
    trainer = pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=args.epochs,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor=monitor, patience=20),
            pl.callbacks.ModelCheckpoint(
                monitor=monitor, save_top_k=1, mode="min",
                dirpath=str(out_dir / f"ckpt_seed{args.seed}"),
            ),
        ],
        enable_progress_bar=True,
        log_every_n_steps=50,
    )

    t0 = time.perf_counter()
    trainer.fit(task, datamodule=qm9)
    train_time = time.perf_counter() - t0

    # ckpt_path="best" restores the best checkpoint's weights into `task`,
    # which the prediction pass below then reuses.
    test_results = trainer.test(task, datamodule=qm9, ckpt_path="best")[0]
    logged_mae = test_results.get(f"test_{spk_target}_MAE")

    preds, labels = _collect_test_predictions(
        task, qm9.test_dataloader(), postprocessors, spk_target, device
    )
    metrics = _regression_metrics(preds, labels)

    result = {
        "method": "schnet",
        "target": args.target,
        "seed": args.seed,
        "n_train": args.n_train,
        "n_val": args.n_val,
        "n_features": args.n_features,
        "n_interactions": args.n_interactions,
        "n_params": sum(p.numel() for p in model.parameters()),
        "train_time_s": train_time,
        "r2": metrics["r2"], "rmse": metrics["rmse"], "mae": metrics["mae"],
        "predictions": preds.tolist(),
        "labels": labels.tolist(),
    }
    with open(out_file, "w", encoding="utf-8") as fp:
        json.dump(result, fp, indent=2)
    print(f"[done] schnet R2={metrics['r2']:.4f} MAE={metrics['mae']:.4g} "
          f"(Lightning-logged test MAE: {logged_mae}) -> {out_file}")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()