"""SchNet baseline using PyTorch Geometric's reference implementation. Backup path for environments where schnetpack's data layer hits the sqlite-key bug. This trainer uses the PyG QM9 dataset directly (no sqlite, no schnetpack), and PyG's `torch_geometric.nn.models.SchNet` which is a clean reference port of Schütt et al. 2017. Usage: python train_baseline_pyg_schnet.py --target gap \ --pyg_root /u/$USER/data/qm9_pyg --seed 0 --out_dir results/ Output schema matches train_baseline_schnet.py / train_starg.py. """ from __future__ import annotations import argparse import json import time from pathlib import Path import numpy as np import torch import torch.nn.functional as F try: from torch_geometric.data import Data from torch_geometric.datasets import QM9 from torch_geometric.loader import DataLoader from torch_geometric.nn.models import SchNet PYG_AVAILABLE = True except ImportError: PYG_AVAILABLE = False # PyG QM9 column → property index mapping. # (PyG's QM9 stores 19 properties per molecule; see the dataset docstring.) PYG_TARGET_INDEX = { "mu": 0, "alpha": 1, "homo": 2, "lumo": 3, "gap": 4, "r2": 5, "zpve": 6, "U0": 7, "U": 8, "H": 9, "G": 10, "Cv": 11, } def main(): if not PYG_AVAILABLE: raise RuntimeError( "torch_geometric is not installed. " "pip install torch_geometric" ) ap = argparse.ArgumentParser() ap.add_argument("--target", required=True) ap.add_argument("--dataset", choices=["qm9", "qm7x"], default="qm9") ap.add_argument("--pyg_root", help="PyG QM9 dataset root (only when --dataset qm9)") ap.add_argument("--qm7x_dir", help="QM7-X HDF5 directory (only when --dataset qm7x)") ap.add_argument("--seed", type=int, default=0) ap.add_argument("--cutoff", type=float, default=10.0) ap.add_argument("--n_features", type=int, default=128) ap.add_argument("--n_interactions", type=int, default=6) ap.add_argument("--n_gaussians", type=int, default=50) ap.add_argument("--epochs", type=int, default=200) ap.add_argument("--lr", type=float, default=5e-4) ap.add_argument("--batch", type=int, default=64) ap.add_argument("--patience", type=int, default=20) ap.add_argument("--n_train", type=int, default=110000, help="train-set size; only used for QM9 (qm7x uses 60/20/20)") ap.add_argument("--n_val", type=int, default=10000, help="val-set size; only used for QM9 (qm7x uses 60/20/20)") ap.add_argument("--out_dir", default="results/") ap.add_argument("--device", default="cuda") args = ap.parse_args() if args.dataset == "qm9" and not args.pyg_root: raise SystemExit("--pyg_root is required when --dataset qm9") if args.dataset == "qm7x" and not args.qm7x_dir: raise SystemExit("--qm7x_dir is required when --dataset qm7x") torch.manual_seed(args.seed) np.random.seed(args.seed) device = torch.device(args.device if torch.cuda.is_available() else "cpu") if args.dataset == "qm9": out_dir = Path(args.out_dir) / "schnet" / args.target else: out_dir = Path(args.out_dir) / "schnet" / "qm7x" / args.target out_dir.mkdir(parents=True, exist_ok=True) out_file = out_dir / f"seed{args.seed}.json" if args.dataset == "qm9": if args.target not in PYG_TARGET_INDEX: raise ValueError(f"unknown qm9 target {args.target}; " f"valid: {list(PYG_TARGET_INDEX)}") tgt_idx = PYG_TARGET_INDEX[args.target] dataset = QM9(args.pyg_root) n_total = len(dataset) print(f"[load] PyG QM9: {n_total} molecules") rng = np.random.default_rng(args.seed) perm = rng.permutation(n_total) train_idx = perm[: args.n_train] val_idx = perm[args.n_train : args.n_train + args.n_val] test_idx = perm[args.n_train + args.n_val :] train_set = dataset[train_idx.tolist()] val_set = dataset[val_idx.tolist()] test_set = dataset[test_idx.tolist()] train_y = torch.stack( [d.y[:, tgt_idx].squeeze() for d in train_set] ).float() y_mean = float(train_y.mean()) y_std = float(train_y.std() + 1e-8) else: # --- QM7-X branch ---------------------------------------------------- from data.qm7x import ( load_qm7x_equilibrium, qm7x_split, qm7x_target_array, ) if args.target not in ("alpha_iso", "alpha_E", "alpha_T2"): raise ValueError(f"unknown qm7x target {args.target}; " f"valid: alpha_iso | alpha_E | alpha_T2") samples = load_qm7x_equilibrium(args.qm7x_dir) n_total = len(samples) print(f"[load] QM7-X equilibrium: {n_total} molecules") targets = qm7x_target_array(samples, args.target) train_idx, val_idx, test_idx = qm7x_split(n_total, seed=args.seed) def _to_pyg_data(s, y_val): return Data( z=torch.tensor(s.Z, dtype=torch.long), pos=torch.tensor(s.coords, dtype=torch.float32), y=torch.tensor([float(y_val)], dtype=torch.float32), ) # Track tgt_idx for shared step() function below; for QM7-X each # Data has a 1-element y, so use tgt_idx=0 (handled in step()). tgt_idx = None train_set = [_to_pyg_data(samples[i], targets[i]) for i in train_idx] val_set = [_to_pyg_data(samples[i], targets[i]) for i in val_idx] test_set = [_to_pyg_data(samples[i], targets[i]) for i in test_idx] train_y = torch.tensor([float(targets[i]) for i in train_idx]) y_mean = float(train_y.mean()) y_std = float(train_y.std() + 1e-8) print(f"[norm] target={args.target} mean={y_mean:.4g} std={y_std:.4g}") train_loader = DataLoader(train_set, batch_size=args.batch, shuffle=True, num_workers=2) val_loader = DataLoader(val_set, batch_size=args.batch, shuffle=False, num_workers=2) test_loader = DataLoader(test_set, batch_size=args.batch, shuffle=False, num_workers=2) model = SchNet( hidden_channels=args.n_features, num_filters=args.n_features, num_interactions=args.n_interactions, num_gaussians=args.n_gaussians, cutoff=args.cutoff, ).to(device) print(f"[model] SchNet params: " f"{sum(p.numel() for p in model.parameters()):,}") optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, factor=0.5, patience=10, min_lr=1e-6, ) def step(loader, train: bool): if train: model.train() else: model.eval() total_loss = 0.0 total_n = 0 all_pred = [] all_true = [] for batch in loader: batch = batch.to(device) # QM9 batches carry a 19-column y (PyG QM9 layout). QM7-X batches # carry a 1-column y by construction. tgt_idx is None for QM7-X. if tgt_idx is None: target = batch.y.view(-1).float() else: target = batch.y[:, tgt_idx].view(-1).float() target_norm = (target - y_mean) / y_std if train: optimizer.zero_grad() with torch.set_grad_enabled(train): pred_norm = model(batch.z, batch.pos, batch.batch).view(-1) loss = F.mse_loss(pred_norm, target_norm) if train: loss.backward() optimizer.step() total_loss += float(loss.item()) * target.numel() total_n += target.numel() if not train: pred = pred_norm.detach().cpu().numpy() * y_std + y_mean all_pred.append(pred) all_true.append(target.detach().cpu().numpy()) if not train: return (total_loss / total_n, np.concatenate(all_pred), np.concatenate(all_true)) return total_loss / total_n t0 = time.time() best_val = float("inf") best_state = None wait = 0 for epoch in range(args.epochs): train_loss = step(train_loader, train=True) val_loss, _, _ = step(val_loader, train=False) scheduler.step(val_loss) if val_loss < best_val: best_val = val_loss best_state = {k: v.detach().clone() for k, v in model.state_dict().items()} wait = 0 else: wait += 1 if wait >= args.patience: print(f"[early-stop] epoch {epoch}, best val={best_val:.4g}") break if epoch % 5 == 0: print(f"[ep {epoch:3d}] train={train_loss:.4g} " f"val={val_loss:.4g} best={best_val:.4g}") train_time = time.time() - t0 if best_state is not None: model.load_state_dict(best_state) _, preds, labels = step(test_loader, train=False) ss_res = float(np.sum((labels - preds) ** 2)) ss_tot = float(np.sum((labels - labels.mean()) ** 2)) r2 = 1.0 - ss_res / max(ss_tot, 1e-12) rmse = float(np.sqrt(((labels - preds) ** 2).mean())) mae = float(np.abs(labels - preds).mean()) result = { "method": "schnet", "dataset": args.dataset, "target": args.target, "seed": args.seed, "n_train": int(len(train_idx)), "n_val": int(len(val_idx)), "n_test": int(len(test_idx)), "n_features": args.n_features, "n_interactions": args.n_interactions, "n_params": int(sum(p.numel() for p in model.parameters())), "train_time_s": float(train_time), "r2": float(r2), "rmse": rmse, "mae": mae, "predictions": preds.tolist(), "labels": labels.tolist(), "test_idx": test_idx.tolist(), "implementation": "torch_geometric.nn.models.SchNet", } with open(out_file, "w") as fp: json.dump(result, fp, indent=2) print(f"[done] PyG-SchNet R2={r2:.4f} MAE={mae:.4g} -> {out_file}") if __name__ == "__main__": main()