# Source: tensor-group-sym/python/large_scale/train_baseline_mace.py
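"""MACE baseline (current SOTA on QM9-class molecular property prediction).

MACE (Batatia et al., 2022) is the state-of-the-art equivariant message-
passing neural network for molecular property prediction, leveraging higher-
order equivariant tensor products via the Atomic Cluster Expansion (ACE)
formalism. We use the official mace-torch package (v0.3.6 pinned). The
imports and model construction are also written against the mace-torch
>= 0.3.10 API: interaction-block classes and Irreps objects are passed
explicitly.

Usage:
    python train_baseline_mace.py --target gap --qm9_dir /path/to/qm9 \
        --seed 0 --out_dir results/
    python train_baseline_mace.py --dataset qm7x --target <name> \
        --qm7x_dir /path/to/qm7x --seed 0 --out_dir results/

Note: a typical MACE-small training run on full QM9 takes ~6 hours per seed
on a single H100. This script regresses a single scalar target per molecule
through MACE's energy (sum-over-atoms) readout. Tensor-valued targets such
as molecular polarizability would require MACE's tensor-output head, which
this script does not configure.
"""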

from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import numpy as np
import torch

try:
    from mace import data, modules, tools
    from mace.tools import torch_geometric
    from mace.calculators.foundations_models import mace_off
    # Newer mace-torch (>= 0.3.10) requires explicit interaction-block
    # classes for ScaleShiftMACE; older versions defaulted them silently.
    from mace.modules import (
        RealAgnosticInteractionBlock,
        RealAgnosticResidualInteractionBlock,
    )
    # mace-torch >= 0.3.10 requires Irreps objects (not strings) for the
    # hidden_irreps and MLP_irreps arguments; older versions accepted both.
    from e3nn import o3
    MACE_AVAILABLE = True
except ImportError:
    MACE_AVAILABLE = False

from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX
from data.dataset_adapter import (
    load_samples,
    build_target,
    split_indices,
    element_set,
)


def _to_mace_atomic_data(sample, z_table, cutoff: float = 5.0):
    """Build a MACE ``AtomicData`` graph for one molecule sample.

    ``z_table`` has to be the shared table for the dataset's complete element
    set (H/C/N/O/F for QM9), constructed once by the caller. Deriving a table
    per molecule from ``config.atomic_numbers`` gives ``node_attrs`` of
    differing widths, and ``Batch.from_data_list`` then cannot concatenate
    molecules whose element subsets differ.

    Args:
        sample: molecule record exposing ``Z`` (atomic numbers) and
            ``coords`` (positions). QM9 samples additionally carry a
            ``properties`` sequence; QM7-X samples do not.
        z_table: atomic-number table shared across the whole dataset.
        cutoff: neighbour-graph radius forwarded to
            ``AtomicData.from_config``.

    Returns:
        The object produced by ``AtomicData.from_config``.
    """
    from mace.data import AtomicData, Configuration

    # mace-torch's Configuration schema wants an "energy" entry, but main()
    # overwrites it per sample once the AtomicData exists. QM9 samples use
    # sample.properties[4] as a non-zero stand-in, kept to steer clear of
    # all-zero degenerate normalization paths in the AtomicData constructor.
    # Samples without `properties` (QM7-X) or with an unusable entry fall
    # back to 0.0.
    energy_stub = 0.0
    if hasattr(sample, "properties"):
        try:
            energy_stub = float(sample.properties[4])
        except (IndexError, TypeError, ValueError):
            pass  # unusable entry: keep the 0.0 stub

    numbers = np.asarray(sample.Z, dtype=int)
    coords = np.asarray(sample.coords, dtype=np.float64)
    config = Configuration(
        atomic_numbers=numbers,
        positions=coords,
        properties={"energy": energy_stub},
        property_weights={"energy": 1.0},
    )
    return AtomicData.from_config(config, z_table=z_table, cutoff=cutoff)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    ap = argparse.ArgumentParser(description="Train a MACE baseline on QM9 or QM7-X.")
    ap.add_argument("--target", required=True)
    ap.add_argument("--dataset", choices=["qm9", "qm7x"], default="qm9")
    ap.add_argument("--qm9_dir",
                    help="QM9 .xyz directory; required when --dataset qm9")
    ap.add_argument("--qm7x_dir",
                    help="QM7-X HDF5 directory; required when --dataset qm7x")
    ap.add_argument("--max_molecules", type=int, default=None)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--batch", type=int, default=32)
    ap.add_argument("--cutoff", type=float, default=5.0)
    ap.add_argument("--max_ell", type=int, default=3)
    ap.add_argument("--correlation", type=int, default=3)
    ap.add_argument("--n_features", type=int, default=128)
    ap.add_argument("--out_dir", default="results/")
    ap.add_argument("--device", default="cuda")
    return ap


def _per_atom_scale_shift(y_train, n_atoms_train) -> tuple[float, float]:
    """Return ``(scale, shift)`` for ScaleShiftMACE from training targets.

    ScaleShiftMACE applies ``scale * e_i + shift`` to every *atomic*
    contribution ``e_i`` before summing over the atoms of a molecule, so both
    statistics must be per-atom quantities. A per-molecule mean used as the
    shift would offset each initial prediction by about
    ``(n_atoms - 1) * mean``. With all E0s zero this matches mace-torch's
    ``compute_mean_std_atomic_inter_energy`` convention.

    Args:
        y_train: one scalar target per training molecule.
        n_atoms_train: atom count of each training molecule, in the same order.
    """
    y_arr = np.asarray(y_train, dtype=np.float64).reshape(-1)
    n_arr = np.asarray(n_atoms_train, dtype=np.float64).reshape(-1)
    per_atom = y_arr / n_arr
    return float(per_atom.std() + 1e-8), float(per_atom.mean())


def _build_graphs(samples, indices, y, z_table, cutoff: float) -> list:
    """Convert ``samples[k]`` (``k`` in ``indices``) to AtomicData targets.

    The placeholder energy set by ``_to_mace_atomic_data`` is replaced by the
    regression target ``y[k]``, stored as a float64 tensor of shape (1,).
    """
    graphs = []
    for k in indices:
        graph = _to_mace_atomic_data(samples[k], z_table, cutoff=cutoff)
        graph.energy = torch.tensor([float(y[k])], dtype=torch.float64)
        graphs.append(graph)
    return graphs


def _avg_num_neighbors(graphs) -> float:
    """Mean number of directed edges per atom across ``graphs``.

    MACE normalises message aggregation by this value. It depends on the
    cutoff and the molecules, so it is measured rather than hard-coded.
    Returns 1.0 when there are no edges, to avoid a zero normaliser.
    """
    n_edges = sum(int(g.edge_index.shape[1]) for g in graphs)
    n_atoms = sum(int(g.positions.shape[0]) for g in graphs)
    if n_edges == 0 or n_atoms == 0:
        return 1.0
    return n_edges / n_atoms


def _make_loader(graphs, batch_size: int, shuffle: bool):
    """Wrap ``graphs`` in MACE's torch_geometric DataLoader."""
    return torch_geometric.dataloader.DataLoader(
        dataset=graphs, batch_size=batch_size, shuffle=shuffle, drop_last=False, num_workers=4,
    )


def _predict(model, loader, device):
    """Run ``model`` in eval mode over ``loader``.

    Returns:
        ``(predictions, labels)`` as flat float64 numpy arrays (empty arrays
        if the loader yields nothing).
    """
    preds, labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch, training=False, compute_force=False)
            preds.append(out["energy"].detach().cpu().numpy().astype(np.float64).reshape(-1))
            labels.append(batch.energy.detach().cpu().numpy().astype(np.float64).reshape(-1))
    if not preds:
        return np.empty(0), np.empty(0)
    return np.concatenate(preds), np.concatenate(labels)


def _regression_metrics(preds, labels) -> tuple[float, float, float]:
    """Return ``(r2, rmse, mae)``. ``r2`` is NaN when the labels have zero variance."""
    preds = np.asarray(preds, dtype=np.float64).reshape(-1)
    labels = np.asarray(labels, dtype=np.float64).reshape(-1)
    resid = labels - preds
    ss_res = float((resid ** 2).sum())
    ss_tot = float(((labels - labels.mean()) ** 2).sum())
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    rmse = float(np.sqrt((resid ** 2).mean()))
    mae = float(np.abs(resid).mean())
    return r2, rmse, mae


def main():
    """Train one MACE seed on ``--dataset``/``--target`` and write a results JSON.

    Output goes to ``<out_dir>/mace/<target>/seed<k>.json`` for QM9 and
    ``<out_dir>/mace/qm7x/<target>/seed<k>.json`` for QM7-X. The file holds
    the test-set R2/RMSE/MAE, the model settings and the raw predictions and
    labels.

    Raises:
        RuntimeError: if mace-torch could not be imported.
        SystemExit: if the data-directory flag for ``--dataset`` is missing,
            ``--n_features`` is not positive, the target is not one scalar per
            molecule, or a train/val/test split is empty.
    """
    if not MACE_AVAILABLE:
        raise RuntimeError("mace-torch is not installed. pip install mace-torch==0.3.6")

    args = _build_arg_parser().parse_args()

    if args.dataset == "qm9" and not args.qm9_dir:
        raise SystemExit("--qm9_dir is required when --dataset qm9")
    if args.dataset == "qm7x" and not args.qm7x_dir:
        raise SystemExit("--qm7x_dir is required when --dataset qm7x")
    if args.n_features < 1:
        raise SystemExit("--n_features must be a positive integer")
    data_dir = args.qm9_dir if args.dataset == "qm9" else args.qm7x_dir

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    # Output layout: results/mace/<target>/seed<k>.json for QM9 (preserved
    # for backwards compat); results/mace/qm7x/<target>/seed<k>.json for
    # QM7-X to keep the two datasets cleanly separated in disk + audits.
    if args.dataset == "qm9":
        out_dir = Path(args.out_dir) / "mace" / args.target
    else:
        out_dir = Path(args.out_dir) / "mace" / "qm7x" / args.target
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"seed{args.seed}.json"

    samples = load_samples(args.dataset, data_dir,
                           max_molecules=args.max_molecules)
    n_total = len(samples)
    train_idx, val_idx, test_idx = split_indices(args.dataset, n_total,
                                                 seed=args.seed)
    # An empty split would either crash at evaluation time or silently
    # disable early stopping; fail up front with an actionable message.
    for split_name, idx in (("train", train_idx), ("val", val_idx), ("test", test_idx)):
        if len(idx) == 0:
            raise SystemExit(
                f"empty {split_name} split (n_total={n_total}); increase --max_molecules"
            )

    # One scalar per molecule; (N,) and (N, 1) targets are both accepted.
    y = np.asarray(build_target(args.dataset, samples, args.target), dtype=np.float64).reshape(-1)
    if y.shape[0] != n_total:
        raise SystemExit(
            f"--target {args.target} yields {y.shape[0]} values for {n_total} molecules; "
            "only scalar per-molecule targets are supported"
        )

    # Scale/shift statistics come from the training split only, so no
    # validation/test information leaks into the model initialisation.
    train_rows = np.asarray(train_idx, dtype=np.int64)
    n_atoms_train = [len(samples[k].Z) for k in train_idx]
    inter_scale, inter_shift = _per_atom_scale_shift(y[train_rows], n_atoms_train)

    z_table = tools.get_atomic_number_table_from_zs(element_set(args.dataset))
    n_elements = len(z_table.zs)

    train_data = _build_graphs(samples, train_idx, y, z_table, args.cutoff)
    val_data = _build_graphs(samples, val_idx, y, z_table, args.cutoff)
    test_data = _build_graphs(samples, test_idx, y, z_table, args.cutoff)

    train_loader = _make_loader(train_data, args.batch, shuffle=True)
    val_loader = _make_loader(val_data, args.batch, shuffle=False)
    test_loader = _make_loader(test_data, args.batch, shuffle=False)

    # Uniform channel multiplicity across l=0 and l=1, as MACE's own CLI
    # builds it from --num_channels.
    hidden_irreps = o3.Irreps(f"{args.n_features}x0e + {args.n_features}x1o")
    avg_num_neighbors = _avg_num_neighbors(train_data)

    model = modules.ScaleShiftMACE(
        r_max=args.cutoff,
        num_bessel=8,
        num_polynomial_cutoff=5,
        max_ell=args.max_ell,
        correlation=args.correlation,
        num_interactions=2,
        num_elements=n_elements,
        hidden_irreps=hidden_irreps,
        MLP_irreps=o3.Irreps("16x0e"),
        atomic_energies=np.zeros(n_elements),  # no E0 reference energies
        avg_num_neighbors=avg_num_neighbors,
        atomic_numbers=z_table.zs,
        gate=torch.nn.functional.silu,
        interaction_cls=RealAgnosticResidualInteractionBlock,
        interaction_cls_first=RealAgnosticInteractionBlock,
        atomic_inter_scale=inter_scale,
        atomic_inter_shift=inter_shift,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=15)
    loss_fn = torch.nn.MSELoss()
    early_stop_patience = 25  # epochs without val improvement before stopping

    best_val, best_state, wait = float("inf"), None, 0
    t0 = time.time()
    for _ in range(args.epochs):
        model.train()
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch, training=True, compute_force=False)
            target = batch.energy.to(out["energy"].dtype)
            loss = loss_fn(out["energy"], target)
            loss.backward()
            optimizer.step()

        val_pred, val_true = _predict(model, val_loader, device)
        val_loss = float(np.mean((val_pred - val_true) ** 2))
        scheduler.step(val_loss)
        if val_loss < best_val:
            best_val, wait = val_loss, 0
            best_state = {name: t.detach().clone() for name, t in model.state_dict().items()}
        else:
            wait += 1
            if wait >= early_stop_patience:
                break
    train_time = time.time() - t0
    if best_state is not None:
        model.load_state_dict(best_state)

    preds, labels = _predict(model, test_loader, device)
    r2, rmse, mae = _regression_metrics(preds, labels)

    result = {
        "method": "mace",
        "target": args.target,
        "seed": args.seed,
        "dataset": args.dataset,
        "n_total": n_total,
        "max_ell": args.max_ell,
        "correlation": args.correlation,
        "n_features": args.n_features,
        "hidden_irreps": str(hidden_irreps),
        "avg_num_neighbors": avg_num_neighbors,
        "atomic_inter_scale": inter_scale,
        "atomic_inter_shift": inter_shift,
        "best_val_mse": best_val,
        "n_params": sum(p.numel() for p in model.parameters()),
        "train_time_s": train_time,
        "r2": r2, "rmse": rmse, "mae": mae,
        "predictions": preds.reshape(-1).tolist(),
        "labels": labels.reshape(-1).tolist(),
    }
    with open(out_file, "w", encoding="utf-8") as fp:
        json.dump(result, fp, indent=2)
    print(f"[done] MACE R2={r2:.4f}  MAE={mae:.4g}  -> {out_file}")


# Run training only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()