"""SchNet baseline using PyTorch Geometric's reference implementation.

Backup path for environments where schnetpack's data layer hits the
sqlite-key bug. This trainer reads the PyG QM9 dataset directly (no
sqlite, no schnetpack) and uses PyG's `torch_geometric.nn.models.SchNet`,
a clean reference port of Schütt et al. 2017. With `--dataset qm7x` it
instead trains on QM7-X equilibrium structures loaded via `data.qm7x`.

Usage:
    python train_baseline_pyg_schnet.py --target gap \
        --pyg_root /u/$USER/data/qm9_pyg --seed 0 --out_dir results/

Output schema matches train_baseline_schnet.py / train_starg.py.
"""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
# torch_geometric is optional at import time so the module can be imported
# (e.g. for its constants) without it; main() refuses to run when the flag
# below is False.
try:
    from torch_geometric.data import Data
    from torch_geometric.datasets import QM9
    from torch_geometric.loader import DataLoader
    from torch_geometric.nn.models import SchNet
except ImportError:
    PYG_AVAILABLE = False
else:
    PYG_AVAILABLE = True
# Property name -> column index into PyG QM9's per-molecule `y` tensor.
# PyG stores 19 properties per molecule (see the torch_geometric QM9
# dataset docstring); only the first 12 columns are exposed as targets.
# The tuple order below IS the column order, so keep it in sync with PyG.
PYG_TARGET_INDEX = {
    name: column
    for column, name in enumerate(
        (
            "mu", "alpha", "homo", "lumo", "gap", "r2",
            "zpve", "U0", "U", "H", "G", "Cv",
        )
    )
}
# Valid --target values for the QM7-X branch (see data.qm7x).
_QM7X_TARGETS = ("alpha_iso", "alpha_E", "alpha_T2")


def _parse_args() -> argparse.Namespace:
    """Parse the CLI and check that the dataset-specific path is present.

    Raises:
        SystemExit: --pyg_root / --qm7x_dir missing for the chosen dataset.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--target", required=True)
    ap.add_argument("--dataset", choices=["qm9", "qm7x"], default="qm9")
    ap.add_argument("--pyg_root",
                    help="PyG QM9 dataset root (only when --dataset qm9)")
    ap.add_argument("--qm7x_dir",
                    help="QM7-X HDF5 directory (only when --dataset qm7x)")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--cutoff", type=float, default=10.0)
    ap.add_argument("--n_features", type=int, default=128)
    ap.add_argument("--n_interactions", type=int, default=6)
    ap.add_argument("--n_gaussians", type=int, default=50)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--batch", type=int, default=64)
    ap.add_argument("--patience", type=int, default=20)
    ap.add_argument("--lr_patience", type=int, default=10,
                    help="ReduceLROnPlateau patience in epochs")
    ap.add_argument("--num_workers", type=int, default=2,
                    help="DataLoader worker processes")
    ap.add_argument("--n_train", type=int, default=110000,
                    help="train-set size; only used for QM9 (qm7x uses 60/20/20)")
    ap.add_argument("--n_val", type=int, default=10000,
                    help="val-set size; only used for QM9 (qm7x uses 60/20/20)")
    ap.add_argument("--out_dir", default="results/")
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()
    if args.dataset == "qm9" and not args.pyg_root:
        raise SystemExit("--pyg_root is required when --dataset qm9")
    if args.dataset == "qm7x" and not args.qm7x_dir:
        raise SystemExit("--qm7x_dir is required when --dataset qm7x")
    return args


def _check_splits(n_total, train_idx, val_idx, test_idx) -> None:
    """Fail fast with a clear message if any split is empty.

    An empty split would otherwise surface later as a ZeroDivisionError in
    the epoch loop or a cryptic ``np.concatenate`` error on an empty list.
    """
    splits = (("train", train_idx), ("val", val_idx), ("test", test_idx))
    for name, idx in splits:
        if len(idx) == 0:
            raise SystemExit(
                f"[split] {name} split is empty ({n_total} molecules; "
                f"train={len(train_idx)} val={len(val_idx)} "
                f"test={len(test_idx)}); check --n_train/--n_val"
            )


def _run_epoch(model, loader, device, tgt_idx, y_mean, y_std, optimizer=None):
    """Run one pass over *loader*: train if *optimizer* is given, else evaluate.

    The model is fit to targets standardized with (y_mean, y_std).

    Args:
        tgt_idx: column of ``batch.y`` to regress (QM9, whose y has 19
            columns), or None to use ``batch.y`` as-is (QM7-X, 1-element y).
        optimizer: when given, gradients are computed and a step is taken
            per batch; when None, runs without gradients.

    Returns:
        Training: mean normalized MSE over all samples.
        Evaluation: (mean normalized MSE, predictions, labels), with the
        predictions de-normalized back to the target's original units.

    Raises:
        ValueError: *loader* yielded no samples.
    """
    train = optimizer is not None
    model.train(train)  # train(False) is equivalent to eval()
    total_loss = 0.0
    total_n = 0
    all_pred = []
    all_true = []
    for batch in loader:
        batch = batch.to(device)
        if tgt_idx is None:
            target = batch.y.view(-1).float()
        else:
            target = batch.y[:, tgt_idx].view(-1).float()
        target_norm = (target - y_mean) / y_std
        if train:
            optimizer.zero_grad()
        with torch.set_grad_enabled(train):
            pred_norm = model(batch.z, batch.pos, batch.batch).view(-1)
            loss = F.mse_loss(pred_norm, target_norm)
        if train:
            loss.backward()
            optimizer.step()
        # Weight by batch size so the result is a per-sample mean even when
        # the last batch is smaller.
        total_loss += float(loss.item()) * target.numel()
        total_n += target.numel()
        if not train:
            pred = pred_norm.detach().cpu().numpy() * y_std + y_mean
            all_pred.append(pred)
            all_true.append(target.detach().cpu().numpy())
    if total_n == 0:
        raise ValueError("loader yielded no samples")
    if not train:
        return (total_loss / total_n,
                np.concatenate(all_pred), np.concatenate(all_true))
    return total_loss / total_n


def _regression_metrics(labels, preds):
    """Return (r2, rmse, mae) of *preds* against *labels*, computed in float64.

    The model outputs float32; accumulating sums of squares over a
    ~10k-sample test set in float32 loses precision, so cast first.
    ss_tot is floored at 1e-12 so a constant-label test set does not divide
    by zero.
    """
    labels = np.asarray(labels, dtype=np.float64)
    preds = np.asarray(preds, dtype=np.float64)
    resid = labels - preds
    ss_res = float(np.sum(resid ** 2))
    ss_tot = float(np.sum((labels - labels.mean()) ** 2))
    r2 = 1.0 - ss_res / max(ss_tot, 1e-12)
    rmse = float(np.sqrt(np.mean(resid ** 2)))
    mae = float(np.mean(np.abs(resid)))
    return r2, rmse, mae


def main():
    """Train PyG SchNet on one target and write test results as JSON.

    Steps:
        1. Load and split the chosen dataset.
        2. Standardize the target with train-split statistics.
        3. Train with Adam + ReduceLROnPlateau, early-stopping on the
           validation loss.
        4. Restore the best-validation weights and evaluate on the test
           split.
        5. Write metrics, predictions, labels and test indices to
           ``<out_dir>/schnet[/qm7x]/<target>/seed<seed>.json``.

    Raises:
        RuntimeError: torch_geometric is not installed.
        ValueError: --target is not valid for the chosen dataset.
        SystemExit: a dataset path is missing, or a split is empty.
    """
    if not PYG_AVAILABLE:
        raise RuntimeError(
            "torch_geometric is not installed. "
            "pip install torch_geometric"
        )
    args = _parse_args()
    # Validate the target before touching the filesystem so a typo does not
    # leave an empty results directory behind.
    if args.dataset == "qm9" and args.target not in PYG_TARGET_INDEX:
        raise ValueError(f"unknown qm9 target {args.target}; "
                         f"valid: {list(PYG_TARGET_INDEX)}")
    if args.dataset == "qm7x" and args.target not in _QM7X_TARGETS:
        raise ValueError(f"unknown qm7x target {args.target}; "
                         f"valid: alpha_iso | alpha_E | alpha_T2")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if not torch.cuda.is_available() and args.device != "cpu":
        print(f"[device] CUDA not available; using cpu instead of {args.device}")
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    out_dir = Path(args.out_dir) / "schnet"
    if args.dataset == "qm7x":
        out_dir = out_dir / "qm7x"
    out_dir = out_dir / args.target
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"seed{args.seed}.json"

    if args.dataset == "qm9":
        tgt_idx = PYG_TARGET_INDEX[args.target]
        dataset = QM9(args.pyg_root)
        n_total = len(dataset)
        print(f"[load] PyG QM9: {n_total} molecules")
        # Seeded permutation split: first n_train -> train, next n_val -> val,
        # remainder -> test.
        rng = np.random.default_rng(args.seed)
        perm = rng.permutation(n_total)
        n_fit = args.n_train + args.n_val
        train_idx = perm[: args.n_train]
        val_idx = perm[args.n_train : n_fit]
        test_idx = perm[n_fit:]
        _check_splits(n_total, train_idx, val_idx, test_idx)
        train_set = dataset[train_idx.tolist()]
        val_set = dataset[val_idx.tolist()]
        test_set = dataset[test_idx.tolist()]
        train_y = torch.stack(
            [d.y[:, tgt_idx].squeeze() for d in train_set]
        ).float()
    else:
        # --- QM7-X branch ----------------------------------------------------
        from data.qm7x import (
            load_qm7x_equilibrium,
            qm7x_split,
            qm7x_target_array,
        )
        samples = load_qm7x_equilibrium(args.qm7x_dir)
        n_total = len(samples)
        print(f"[load] QM7-X equilibrium: {n_total} molecules")
        targets = qm7x_target_array(samples, args.target)
        train_idx, val_idx, test_idx = qm7x_split(n_total, seed=args.seed)
        _check_splits(n_total, train_idx, val_idx, test_idx)

        def _to_pyg_data(s, y_val):
            return Data(
                z=torch.tensor(s.Z, dtype=torch.long),
                pos=torch.tensor(s.coords, dtype=torch.float32),
                y=torch.tensor([float(y_val)], dtype=torch.float32),
            )

        # Each QM7-X Data carries a 1-element y, so tgt_idx=None tells
        # _run_epoch to use batch.y as-is instead of selecting a column.
        tgt_idx = None
        train_set = [_to_pyg_data(samples[i], targets[i]) for i in train_idx]
        val_set = [_to_pyg_data(samples[i], targets[i]) for i in val_idx]
        test_set = [_to_pyg_data(samples[i], targets[i]) for i in test_idx]
        train_y = torch.tensor([float(targets[i]) for i in train_idx])
    # Standardization statistics come from the training split only.
    y_mean = float(train_y.mean())
    y_std = float(train_y.std() + 1e-8)
    print(f"[norm] target={args.target} mean={y_mean:.4g} std={y_std:.4g}")

    loader_kw = dict(batch_size=args.batch, num_workers=args.num_workers)
    train_loader = DataLoader(train_set, shuffle=True, **loader_kw)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kw)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kw)

    model = SchNet(
        hidden_channels=args.n_features,
        num_filters=args.n_features,
        num_interactions=args.n_interactions,
        num_gaussians=args.n_gaussians,
        cutoff=args.cutoff,
    ).to(device)
    n_params = int(sum(p.numel() for p in model.parameters()))
    print(f"[model] SchNet params: {n_params:,}")
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=args.lr_patience, min_lr=1e-6,
    )
    epoch_kw = dict(device=device, tgt_idx=tgt_idx, y_mean=y_mean, y_std=y_std)

    t0 = time.time()
    best_val = float("inf")
    best_state = None
    wait = 0
    for epoch in range(args.epochs):
        train_loss = _run_epoch(model, train_loader, optimizer=optimizer,
                                **epoch_kw)
        val_loss, _, _ = _run_epoch(model, val_loader, **epoch_kw)
        scheduler.step(val_loss)
        if val_loss < best_val:
            best_val = val_loss
            # Clone so later optimizer steps do not mutate the snapshot.
            best_state = {k: v.detach().clone()
                          for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1
            if wait >= args.patience:
                print(f"[early-stop] epoch {epoch}, best val={best_val:.4g}")
                break
        if epoch % 5 == 0:
            print(f"[ep {epoch:3d}] train={train_loss:.4g} "
                  f"val={val_loss:.4g} best={best_val:.4g}")
    train_time = time.time() - t0

    # best_state stays None only if no validation loss ever beat +inf
    # (e.g. all NaN); then the final weights are evaluated.
    if best_state is not None:
        model.load_state_dict(best_state)
    _, preds, labels = _run_epoch(model, test_loader, **epoch_kw)
    r2, rmse, mae = _regression_metrics(labels, preds)

    result = {
        "method": "schnet",
        "dataset": args.dataset,
        "target": args.target,
        "seed": args.seed,
        "n_train": int(len(train_idx)),
        "n_val": int(len(val_idx)),
        "n_test": int(len(test_idx)),
        "n_features": args.n_features,
        "n_interactions": args.n_interactions,
        "n_params": n_params,
        "train_time_s": float(train_time),
        "r2": float(r2),
        "rmse": rmse,
        "mae": mae,
        "predictions": preds.tolist(),
        "labels": labels.tolist(),
        # np.asarray: qm7x_split's return type is not visible here; this
        # works for both ndarrays and plain lists.
        "test_idx": np.asarray(test_idx).tolist(),
        "implementation": "torch_geometric.nn.models.SchNet",
    }
    with open(out_file, "w") as fp:
        json.dump(result, fp, indent=2)
    print(f"[done] PyG-SchNet R2={r2:.4f} MAE={mae:.4g} -> {out_file}")
if __name__ == "__main__":
    # Script entry point only; importing this module does not start training.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())