"""MACE baseline (current SOTA on QM9-class molecular property prediction). MACE (Batatia et al., 2022) is the state-of-the-art equivariant message- passing neural network for molecular property prediction, leveraging higher- order equivariant tensor products via the Atomic Cluster Expansion (ACE) formalism. We use the official mace-torch package (v0.3.6 pinned). Usage: python train_baseline_mace.py --target gap --qm9_dir /path/to/qm9 \ --seed 0 --out_dir results/ Note: a typical MACE-small training run on full QM9 takes ~6 hours per seed on a single H100. For molecular polarizability tensors, MACE outputs the appropriate irreps via its built-in tensor-output head. """ from __future__ import annotations import argparse import json import time from pathlib import Path import numpy as np import torch try: from mace import data, modules, tools from mace.tools import torch_geometric from mace.calculators.foundations_models import mace_off # Newer mace-torch (>= 0.3.10) requires explicit interaction-block # classes for ScaleShiftMACE; older versions defaulted them silently. from mace.modules import ( RealAgnosticInteractionBlock, RealAgnosticResidualInteractionBlock, ) # mace-torch >= 0.3.10 requires Irreps objects (not strings) for the # hidden_irreps and MLP_irreps arguments; older versions accepted both. from e3nn import o3 MACE_AVAILABLE = True except ImportError: MACE_AVAILABLE = False from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX from data.dataset_adapter import ( load_samples, build_target, split_indices, element_set, ) def _to_mace_atomic_data(sample, z_table, cutoff: float = 5.0): """Convert a QM9Sample to a MACE AtomicData object. The z_table must be built once at the call site from the *full* element set (H/C/N/O/F for QM9). Building it per-molecule from `config.atomic_numbers` produces variable-width `node_attrs` and `Batch.from_data_list` then refuses to concatenate molecules with different element subsets. """ from mace.data import AtomicData, Configuration # AtomicData requires *some* energy in the Configuration to satisfy # mace-torch's schema; the value is overridden per-sample after # construction in main(). For QM9 samples we use a non-zero placeholder # (sample.properties[4]) to avoid all-zero degenerate normalization # paths inside the AtomicData constructor; for QM7-X samples (no # `properties` attr) the placeholder is 0.0. placeholder_energy = 0.0 if hasattr(sample, "properties"): try: placeholder_energy = float(sample.properties[4]) except (IndexError, TypeError, ValueError): placeholder_energy = 0.0 config = Configuration( atomic_numbers=np.asarray(sample.Z, dtype=int), positions=np.asarray(sample.coords, dtype=np.float64), properties={"energy": placeholder_energy}, property_weights={"energy": 1.0}, ) return AtomicData.from_config(config, z_table=z_table, cutoff=cutoff) def main(): if not MACE_AVAILABLE: raise RuntimeError("mace-torch is not installed. pip install mace-torch==0.3.6") ap = argparse.ArgumentParser() ap.add_argument("--target", required=True) ap.add_argument("--dataset", choices=["qm9", "qm7x"], default="qm9") ap.add_argument("--qm9_dir", help="QM9 .xyz directory; required when --dataset qm9") ap.add_argument("--qm7x_dir", help="QM7-X HDF5 directory; required when --dataset qm7x") ap.add_argument("--max_molecules", type=int, default=None) ap.add_argument("--seed", type=int, default=0) ap.add_argument("--epochs", type=int, default=200) ap.add_argument("--lr", type=float, default=1e-3) ap.add_argument("--batch", type=int, default=32) ap.add_argument("--cutoff", type=float, default=5.0) ap.add_argument("--max_ell", type=int, default=3) ap.add_argument("--correlation", type=int, default=3) ap.add_argument("--n_features", type=int, default=128) ap.add_argument("--out_dir", default="results/") ap.add_argument("--device", default="cuda") args = ap.parse_args() if args.dataset == "qm9" and not args.qm9_dir: raise SystemExit("--qm9_dir is required when --dataset qm9") if args.dataset == "qm7x" and not args.qm7x_dir: raise SystemExit("--qm7x_dir is required when --dataset qm7x") data_dir = args.qm9_dir if args.dataset == "qm9" else args.qm7x_dir torch.manual_seed(args.seed) np.random.seed(args.seed) device = torch.device(args.device if torch.cuda.is_available() else "cpu") # Output layout: results/mace//seed.json for QM9 (preserved # for backwards compat); results/mace/qm7x//seed.json for # QM7-X to keep the two datasets cleanly separated in disk + audits. if args.dataset == "qm9": out_dir = Path(args.out_dir) / "mace" / args.target else: out_dir = Path(args.out_dir) / "mace" / "qm7x" / args.target out_dir.mkdir(parents=True, exist_ok=True) out_file = out_dir / f"seed{args.seed}.json" samples = load_samples(args.dataset, data_dir, max_molecules=args.max_molecules) n_total = len(samples) train_idx, val_idx, test_idx = split_indices(args.dataset, n_total, seed=args.seed) y = build_target(args.dataset, samples, args.target) y_mean = float(y.mean()); y_std = float(y.std() + 1e-8) z_table = tools.get_atomic_number_table_from_zs(element_set(args.dataset)) n_elements = len(z_table.zs) # Build MACE configuration atomic_energies = np.zeros(n_elements) model = modules.ScaleShiftMACE( r_max=args.cutoff, num_bessel=8, num_polynomial_cutoff=5, max_ell=args.max_ell, correlation=args.correlation, num_interactions=2, num_elements=n_elements, hidden_irreps=o3.Irreps("128x0e + 128x1o") if args.n_features >= 128 else o3.Irreps("64x0e + 32x1o"), MLP_irreps=o3.Irreps("16x0e"), atomic_energies=atomic_energies, avg_num_neighbors=8.0, atomic_numbers=z_table.zs, gate=torch.nn.functional.silu, interaction_cls=RealAgnosticResidualInteractionBlock, interaction_cls_first=RealAgnosticInteractionBlock, atomic_inter_scale=y_std, atomic_inter_shift=y_mean, ).to(device) train_data = [_to_mace_atomic_data(samples[i], z_table, cutoff=args.cutoff) for i in train_idx] val_data = [_to_mace_atomic_data(samples[i], z_table, cutoff=args.cutoff) for i in val_idx] test_data = [_to_mace_atomic_data(samples[i], z_table, cutoff=args.cutoff) for i in test_idx] # Override the per-sample target after AtomicData construction for i, k in enumerate(train_idx): train_data[i].energy = torch.tensor([float(y[k])], dtype=torch.float64) for i, k in enumerate(val_idx): val_data[i].energy = torch.tensor([float(y[k])], dtype=torch.float64) for i, k in enumerate(test_idx): test_data[i].energy = torch.tensor([float(y[k])], dtype=torch.float64) train_loader = torch_geometric.dataloader.DataLoader( dataset=train_data, batch_size=args.batch, shuffle=True, drop_last=False, num_workers=4, ) val_loader = torch_geometric.dataloader.DataLoader( dataset=val_data, batch_size=args.batch, shuffle=False, drop_last=False, num_workers=4, ) test_loader = torch_geometric.dataloader.DataLoader( dataset=test_data, batch_size=args.batch, shuffle=False, drop_last=False, num_workers=4, ) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=15) loss_fn = torch.nn.MSELoss() best_val, best_state, wait = float("inf"), None, 0 t0 = time.time() for ep in range(args.epochs): model.train() for batch in train_loader: batch = batch.to(device) optimizer.zero_grad() out = model(batch, training=True, compute_force=False) target = batch.energy.to(out["energy"].dtype) loss = loss_fn(out["energy"], target) loss.backward() optimizer.step() # Validate model.eval() with torch.no_grad(): val_loss = 0.0; n = 0 for batch in val_loader: batch = batch.to(device) out = model(batch, training=False, compute_force=False) target = batch.energy.to(out["energy"].dtype) val_loss += ((out["energy"] - target) ** 2).sum().item() n += target.numel() val_loss /= max(n, 1) scheduler.step(val_loss) if val_loss < best_val: best_val, best_state, wait = val_loss, {k: v.detach().clone() for k, v in model.state_dict().items()}, 0 else: wait += 1 if wait >= 25: break train_time = time.time() - t0 if best_state: model.load_state_dict(best_state) # Test model.eval() preds, labels = [], [] with torch.no_grad(): for batch in test_loader: batch = batch.to(device) out = model(batch, training=False, compute_force=False) preds.append(out["energy"].cpu().numpy()) labels.append(batch.energy.cpu().numpy()) preds = np.concatenate(preds); labels = np.concatenate(labels) ss_res = ((labels - preds) ** 2).sum(); ss_tot = ((labels - labels.mean()) ** 2).sum() r2 = float(1 - ss_res / ss_tot) rmse = float(np.sqrt(((labels - preds) ** 2).mean())) mae = float(np.abs(labels - preds).mean()) result = { "method": "mace", "target": args.target, "seed": args.seed, "dataset": args.dataset, "n_total": n_total, "max_ell": args.max_ell, "correlation": args.correlation, "n_params": sum(p.numel() for p in model.parameters()), "train_time_s": train_time, "r2": r2, "rmse": rmse, "mae": mae, "predictions": preds.reshape(-1).tolist(), "labels": labels.reshape(-1).tolist(), } with open(out_file, "w") as fp: json.dump(result, fp, indent=2) print(f"[done] MACE R2={r2:.4f} MAE={mae:.4g} -> {out_file}") if __name__ == "__main__": main()