# tensor-group-sym / python / large_scale / train_baseline_pyg_schnet.py
"""SchNet baseline using PyTorch Geometric's reference implementation.

Backup path for environments where schnetpack's data layer hits the
sqlite-key bug. This trainer avoids sqlite and schnetpack entirely: it
reads QM9 through PyG's QM9 dataset class (or QM7-X through the local
`data.qm7x` loader) and trains PyG's `torch_geometric.nn.models.SchNet`,
a clean reference port of Schütt et al. 2017.

Usage:
    python train_baseline_pyg_schnet.py --target gap \
        --pyg_root /u/$USER/data/qm9_pyg --seed 0 --out_dir results/
    python train_baseline_pyg_schnet.py --dataset qm7x --target alpha_iso \
        --qm7x_dir /path/to/qm7x --seed 0 --out_dir results/

Output schema matches train_baseline_schnet.py / train_starg.py.
"""

from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

try:
    from torch_geometric.data import Data
    from torch_geometric.datasets import QM9
    from torch_geometric.loader import DataLoader
    from torch_geometric.nn.models import SchNet
    PYG_AVAILABLE = True
except ImportError:
    PYG_AVAILABLE = False


# Target name -> column of PyG QM9's per-molecule `y` row.
# PyG's QM9 stores 19 properties per molecule (see the dataset docstring);
# only the first 12 columns are mapped here. Names are listed in column
# order, so each name's position in the tuple is its column index.
_PYG_QM9_TARGET_NAMES = (
    "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve",
    "U0", "U", "H", "G", "Cv",
)
PYG_TARGET_INDEX = {
    target_name: column
    for column, target_name in enumerate(_PYG_QM9_TARGET_NAMES)
}


def _require_nonempty_splits(n_total, **splits):
    """Abort with a readable message if any named index split is empty.

    Args:
        n_total: number of molecules the splits were drawn from (used only in
            the error message).
        **splits: split name -> sequence of sample indices.

    Raises:
        SystemExit: if any split has length 0. Without this check an empty
            split makes the per-epoch loss average in ``step()`` divide by
            zero -- for the test split only after the full training run.
    """
    sizes = ", ".join(f"{name}={len(idx)}" for name, idx in splits.items())
    for name, idx in splits.items():
        if len(idx) == 0:
            raise SystemExit(
                f"empty {name} split from {n_total} molecules ({sizes})"
            )


def main():
    """Train PyG's reference SchNet on one QM9 or QM7-X target and dump metrics.

    Parses CLI arguments and builds seeded train/val/test splits. Standardizes
    the target with train-split mean/std. Trains with Adam +
    ReduceLROnPlateau and patience-based early stopping on validation MSE,
    then restores the best weights. Writes test-set R2/RMSE/MAE plus raw
    predictions to ``<out_dir>/schnet[/qm7x]/<target>/seed<seed>.json``.

    Raises:
        RuntimeError: if torch_geometric could not be imported.
        SystemExit: on missing/invalid CLI arguments or an empty data split.
        ValueError: if ``--target`` is not valid for the chosen dataset.
    """
    if not PYG_AVAILABLE:
        raise RuntimeError(
            "torch_geometric is not installed. "
            "pip install torch_geometric"
        )

    ap = argparse.ArgumentParser()
    ap.add_argument("--target", required=True)
    ap.add_argument("--dataset", choices=["qm9", "qm7x"], default="qm9")
    ap.add_argument("--pyg_root",
                    help="PyG QM9 dataset root (only when --dataset qm9)")
    ap.add_argument("--qm7x_dir",
                    help="QM7-X HDF5 directory (only when --dataset qm7x)")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--cutoff", type=float, default=10.0)
    ap.add_argument("--n_features", type=int, default=128)
    ap.add_argument("--n_interactions", type=int, default=6)
    ap.add_argument("--n_gaussians", type=int, default=50)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--batch", type=int, default=64)
    ap.add_argument("--patience", type=int, default=20)
    ap.add_argument("--n_train", type=int, default=110000,
                    help="train-set size; only used for QM9 (qm7x uses 60/20/20)")
    ap.add_argument("--n_val", type=int, default=10000,
                    help="val-set size; only used for QM9 (qm7x uses 60/20/20)")
    ap.add_argument("--out_dir", default="results/")
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()

    if args.dataset == "qm9" and not args.pyg_root:
        raise SystemExit("--pyg_root is required when --dataset qm9")
    if args.dataset == "qm7x" and not args.qm7x_dir:
        raise SystemExit("--qm7x_dir is required when --dataset qm7x")
    if args.dataset == "qm9" and (args.n_train <= 0 or args.n_val <= 0):
        # Non-positive sizes would turn the permutation slices below into
        # negative-bound slices, giving overlapping or empty splits.
        raise SystemExit("--n_train and --n_val must be positive for --dataset qm9")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Falls back to CPU whenever CUDA is unavailable, whatever --device says.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if args.dataset == "qm9":
        out_dir = Path(args.out_dir) / "schnet" / args.target
    else:
        out_dir = Path(args.out_dir) / "schnet" / "qm7x" / args.target
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"seed{args.seed}.json"

    if args.dataset == "qm9":
        if args.target not in PYG_TARGET_INDEX:
            raise ValueError(f"unknown qm9 target {args.target}; "
                             f"valid: {list(PYG_TARGET_INDEX)}")
        tgt_idx = PYG_TARGET_INDEX[args.target]

        dataset = QM9(args.pyg_root)
        n_total = len(dataset)
        print(f"[load] PyG QM9: {n_total} molecules")

        rng = np.random.default_rng(args.seed)
        perm = rng.permutation(n_total)
        train_idx = perm[: args.n_train]
        val_idx = perm[args.n_train : args.n_train + args.n_val]
        test_idx = perm[args.n_train + args.n_val :]
        # Fail before training if n_train + n_val leaves no test molecules.
        _require_nonempty_splits(
            n_total, train=train_idx, val=val_idx, test=test_idx
        )

        train_set = dataset[train_idx.tolist()]
        val_set = dataset[val_idx.tolist()]
        test_set = dataset[test_idx.tolist()]

        # Normalization statistics come from the train split only (no leakage).
        train_y = torch.stack(
            [d.y[:, tgt_idx].squeeze() for d in train_set]
        ).float()
        y_mean = float(train_y.mean())
        y_std = float(train_y.std() + 1e-8)
    else:
        # --- QM7-X branch ----------------------------------------------------
        from data.qm7x import (
            load_qm7x_equilibrium,
            qm7x_split,
            qm7x_target_array,
        )
        if args.target not in ("alpha_iso", "alpha_E", "alpha_T2"):
            raise ValueError(f"unknown qm7x target {args.target}; "
                             f"valid: alpha_iso | alpha_E | alpha_T2")
        samples = load_qm7x_equilibrium(args.qm7x_dir)
        n_total = len(samples)
        print(f"[load] QM7-X equilibrium: {n_total} molecules")
        targets = qm7x_target_array(samples, args.target)
        # Coerce whatever sequence type qm7x_split returns into int arrays so
        # the `.tolist()` in the result dump below cannot fail after training.
        train_idx, val_idx, test_idx = (
            np.asarray(idx, dtype=np.int64)
            for idx in qm7x_split(n_total, seed=args.seed)
        )
        _require_nonempty_splits(
            n_total, train=train_idx, val=val_idx, test=test_idx
        )

        def _to_pyg_data(s, y_val):
            return Data(
                z=torch.tensor(s.Z, dtype=torch.long),
                pos=torch.tensor(s.coords, dtype=torch.float32),
                y=torch.tensor([float(y_val)], dtype=torch.float32),
            )
        # QM9 Data objects carry the full PyG target row and step() selects
        # column tgt_idx; the QM7-X Data objects built above carry a single
        # target value, so tgt_idx=None tells step() to use y as-is.
        tgt_idx = None

        train_set = [_to_pyg_data(samples[i], targets[i]) for i in train_idx]
        val_set = [_to_pyg_data(samples[i], targets[i]) for i in val_idx]
        test_set = [_to_pyg_data(samples[i], targets[i]) for i in test_idx]

        train_y = torch.tensor([float(targets[i]) for i in train_idx])
        y_mean = float(train_y.mean())
        y_std = float(train_y.std() + 1e-8)

    print(f"[norm] target={args.target} mean={y_mean:.4g} std={y_std:.4g}")

    train_loader = DataLoader(train_set, batch_size=args.batch,
                              shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=args.batch,
                            shuffle=False, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=args.batch,
                             shuffle=False, num_workers=2)

    model = SchNet(
        hidden_channels=args.n_features,
        num_filters=args.n_features,
        num_interactions=args.n_interactions,
        num_gaussians=args.n_gaussians,
        cutoff=args.cutoff,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] SchNet params: {n_params:,}")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=10, min_lr=1e-6,
    )

    def step(loader, train: bool):
        """Run one pass over ``loader`` (optimizing when ``train`` is True).

        The loss is MSE on the standardized target. Returns the
        sample-weighted mean loss. When ``train`` is False, also returns
        de-standardized predictions and raw targets as numpy arrays.
        """
        if train:
            model.train()
        else:
            model.eval()
        total_loss = 0.0
        total_n = 0
        all_pred = []
        all_true = []
        for batch in loader:
            batch = batch.to(device)
            # QM9 batches carry a 19-column y (PyG QM9 layout). QM7-X batches
            # carry a 1-column y by construction. tgt_idx is None for QM7-X.
            if tgt_idx is None:
                target = batch.y.view(-1).float()
            else:
                target = batch.y[:, tgt_idx].view(-1).float()
            target_norm = (target - y_mean) / y_std
            if train:
                optimizer.zero_grad()
            with torch.set_grad_enabled(train):
                pred_norm = model(batch.z, batch.pos, batch.batch).view(-1)
                loss = F.mse_loss(pred_norm, target_norm)
            if train:
                loss.backward()
                optimizer.step()
            total_loss += float(loss.item()) * target.numel()
            total_n += target.numel()
            if not train:
                pred = pred_norm.detach().cpu().numpy() * y_std + y_mean
                all_pred.append(pred)
                all_true.append(target.detach().cpu().numpy())
        if not train:
            return (total_loss / total_n,
                    np.concatenate(all_pred), np.concatenate(all_true))
        return total_loss / total_n

    t0 = time.time()
    best_val = float("inf")
    best_state = None
    wait = 0
    for epoch in range(args.epochs):
        train_loss = step(train_loader, train=True)
        val_loss, _, _ = step(val_loader, train=False)
        scheduler.step(val_loss)
        if val_loss < best_val:
            best_val = val_loss
            # Snapshot (not reference) the weights so later epochs can't mutate them.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1
            if wait >= args.patience:
                print(f"[early-stop] epoch {epoch}, best val={best_val:.4g}")
                break
        if epoch % 5 == 0:
            print(f"[ep {epoch:3d}] train={train_loss:.4g} "
                  f"val={val_loss:.4g} best={best_val:.4g}")
    train_time = time.time() - t0

    if best_state is not None:
        model.load_state_dict(best_state)

    _, preds, labels = step(test_loader, train=False)
    ss_res = float(np.sum((labels - preds) ** 2))
    ss_tot = float(np.sum((labels - labels.mean()) ** 2))
    r2 = 1.0 - ss_res / max(ss_tot, 1e-12)
    rmse = float(np.sqrt(((labels - preds) ** 2).mean()))
    mae = float(np.abs(labels - preds).mean())

    result = {
        "method": "schnet",
        "dataset": args.dataset,
        "target": args.target,
        "seed": args.seed,
        "n_train": int(len(train_idx)),
        "n_val": int(len(val_idx)),
        "n_test": int(len(test_idx)),
        "n_features": args.n_features,
        "n_interactions": args.n_interactions,
        "n_params": int(n_params),
        "train_time_s": float(train_time),
        "r2": float(r2),
        "rmse": rmse,
        "mae": mae,
        "predictions": preds.tolist(),
        "labels": labels.tolist(),
        "test_idx": test_idx.tolist(),
        "implementation": "torch_geometric.nn.models.SchNet",
    }
    with open(out_file, "w") as fp:
        json.dump(result, fp, indent=2)
    print(f"[done] PyG-SchNet R2={r2:.4f}  MAE={mae:.4g}  -> {out_file}")


# Run the training CLI only when this file is executed as a script; importing
# the module just defines main() and the target mapping.
if __name__ == "__main__":
    main()