"""MACE baseline (current SOTA on QM9-class molecular property prediction).

MACE (Batatia et al., 2022) is a state-of-the-art equivariant message-
passing neural network for molecular property prediction, leveraging higher-
order equivariant tensor products via the Atomic Cluster Expansion (ACE)
formalism. We use the official mace-torch package (v0.3.6 pinned).

Usage:
    python train_baseline_mace.py --target gap --qm9_dir /path/to/qm9 \
        --seed 0 --out_dir results/

Note: a typical MACE-small training run on full QM9 takes ~6 hours per seed
on a single H100. For molecular polarizability tensors, MACE outputs the
appropriate irreps via its built-in tensor-output head; this script itself
trains and evaluates only the scalar ``energy`` output.
"""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
import numpy as np
import torch
try:
from mace import data, modules, tools
from mace.tools import torch_geometric
from mace.calculators.foundations_models import mace_off
# Newer mace-torch (>= 0.3.10) requires explicit interaction-block
# classes for ScaleShiftMACE; older versions defaulted them silently.
from mace.modules import (
RealAgnosticInteractionBlock,
RealAgnosticResidualInteractionBlock,
)
# mace-torch >= 0.3.10 requires Irreps objects (not strings) for the
# hidden_irreps and MLP_irreps arguments; older versions accepted both.
from e3nn import o3
MACE_AVAILABLE = True
except ImportError:
MACE_AVAILABLE = False
from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX
from data.dataset_adapter import (
load_samples,
build_target,
split_indices,
element_set,
)
def _to_mace_atomic_data(sample, z_table, cutoff: float = 5.0):
    """Wrap a single molecule sample as a MACE ``AtomicData`` graph.

    Args:
        sample: Object exposing ``Z`` (atomic numbers) and ``coords``
            (positions); it may also expose an indexable ``properties``.
        z_table: Atomic-number table shared by every molecule. It has to be
            created once by the caller from the *full* element set (H/C/N/O/F
            for QM9). A per-molecule table built from
            ``config.atomic_numbers`` gives ``node_attrs`` of varying width,
            and ``Batch.from_data_list`` then refuses to concatenate
            molecules with different element subsets.
        cutoff: Neighbour cutoff radius forwarded to
            ``AtomicData.from_config``.

    Returns:
        The ``AtomicData`` produced by ``AtomicData.from_config``. Its energy
        is only a placeholder; ``main()`` overwrites it with the real target.
    """
    from mace.data import AtomicData, Configuration

    # The Configuration must carry *some* energy. When the sample has a
    # ``properties`` sequence we borrow entry 4 as a non-zero stand-in (the
    # original author's note says this avoids a degenerate all-zero
    # normalization path inside AtomicData -- not verifiable from this file);
    # samples without ``properties`` (e.g. QM7-X) fall back to 0.0.
    energy_stub = 0.0
    sample_props = getattr(sample, "properties", None)
    if sample_props is not None:
        try:
            energy_stub = float(sample_props[4])
        except (IndexError, TypeError, ValueError):
            energy_stub = 0.0

    numbers = np.asarray(sample.Z, dtype=int)
    xyz = np.asarray(sample.coords, dtype=np.float64)
    molecule = Configuration(
        atomic_numbers=numbers,
        positions=xyz,
        properties={"energy": energy_stub},
        property_weights={"energy": 1.0},
    )
    return AtomicData.from_config(molecule, z_table=z_table, cutoff=cutoff)
def _make_loader(samples, y, indices, z_table, cutoff, batch_size, shuffle):
    """Build a MACE DataLoader over ``samples[indices]`` labelled with ``y[indices]``."""
    dataset = []
    for k in indices:
        atomic_data = _to_mace_atomic_data(samples[k], z_table, cutoff=cutoff)
        # Replace the placeholder energy with the actual regression target.
        atomic_data.energy = torch.tensor([float(y[k])], dtype=torch.float64)
        dataset.append(atomic_data)
    return torch_geometric.dataloader.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,
        num_workers=4,
    )


def _regression_metrics(labels, preds):
    """Return ``(r2, rmse, mae)`` for two flat, equally long NumPy arrays."""
    residual = labels - preds
    ss_res = (residual ** 2).sum()
    ss_tot = ((labels - labels.mean()) ** 2).sum()
    # R^2 is undefined when every label is identical; report NaN instead of
    # dividing by zero.
    r2 = float(1 - ss_res / ss_tot) if ss_tot > 0 else float("nan")
    rmse = float(np.sqrt((residual ** 2).mean()))
    mae = float(np.abs(residual).mean())
    return r2, rmse, mae


def main():
    """Train a ScaleShiftMACE regressor for one target and write test metrics.

    Parses the command line, loads the requested dataset, trains with Adam and
    ReduceLROnPlateau, early-stops on validation MSE, restores the best
    weights, evaluates on the test split and writes a JSON report to
    ``<out_dir>/mace/<target>/seed<k>.json`` (QM9) or
    ``<out_dir>/mace/qm7x/<target>/seed<k>.json`` (QM7-X).

    Raises:
        RuntimeError: If the MACE imports at the top of the file failed.
        SystemExit: If the dataset directory argument is missing, or if the
            train or test split is empty.
    """
    if not MACE_AVAILABLE:
        # NOTE(review): the import block's comments describe mace-torch
        # >= 0.3.10 behaviour while this hint pins 0.3.6 -- confirm which
        # version is actually intended.
        raise RuntimeError("mace-torch is not installed. pip install mace-torch==0.3.6")

    ap = argparse.ArgumentParser()
    ap.add_argument("--target", required=True)
    ap.add_argument("--dataset", choices=["qm9", "qm7x"], default="qm9")
    ap.add_argument("--qm9_dir",
                    help="QM9 .xyz directory; required when --dataset qm9")
    ap.add_argument("--qm7x_dir",
                    help="QM7-X HDF5 directory; required when --dataset qm7x")
    ap.add_argument("--max_molecules", type=int, default=None)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--batch", type=int, default=32)
    ap.add_argument("--cutoff", type=float, default=5.0)
    ap.add_argument("--max_ell", type=int, default=3)
    ap.add_argument("--correlation", type=int, default=3)
    ap.add_argument("--n_features", type=int, default=128)
    ap.add_argument("--out_dir", default="results/")
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()

    if args.dataset == "qm9" and not args.qm9_dir:
        raise SystemExit("--qm9_dir is required when --dataset qm9")
    if args.dataset == "qm7x" and not args.qm7x_dir:
        raise SystemExit("--qm7x_dir is required when --dataset qm7x")
    data_dir = args.qm9_dir if args.dataset == "qm9" else args.qm7x_dir

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Fall back to CPU only when a CUDA device was requested but is missing;
    # an explicit non-CUDA request (e.g. "cpu", "mps") is honoured as given.
    requested_device = torch.device(args.device)
    if requested_device.type == "cuda" and not torch.cuda.is_available():
        print("[warn] CUDA requested but unavailable; falling back to CPU")
        device = torch.device("cpu")
    else:
        device = requested_device

    # Output layout: results/mace/<target>/seed<k>.json for QM9 (preserved
    # for backwards compat); results/mace/qm7x/<target>/seed<k>.json for
    # QM7-X to keep the two datasets cleanly separated in disk + audits.
    if args.dataset == "qm9":
        out_dir = Path(args.out_dir) / "mace" / args.target
    else:
        out_dir = Path(args.out_dir) / "mace" / "qm7x" / args.target
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"seed{args.seed}.json"

    samples = load_samples(args.dataset, data_dir,
                           max_molecules=args.max_molecules)
    n_total = len(samples)
    train_idx, val_idx, test_idx = split_indices(args.dataset, n_total,
                                                 seed=args.seed)
    # Fail before training rather than crashing afterwards when metrics are
    # computed on an empty test set.
    if len(train_idx) == 0 or len(test_idx) == 0:
        raise SystemExit(
            "train/test split is empty; increase --max_molecules or check the dataset"
        )

    y = build_target(args.dataset, samples, args.target)
    # Normalisation statistics come from the training split only, so that
    # validation/test labels do not leak into the model's scale and shift.
    y_train = y[np.asarray(train_idx, dtype=np.int64)]
    y_mean = float(y_train.mean())
    y_std = float(y_train.std() + 1e-8)

    z_table = tools.get_atomic_number_table_from_zs(element_set(args.dataset))
    n_elements = len(z_table.zs)

    # Build MACE configuration. Per-element reference energies are all zero.
    atomic_energies = np.zeros(n_elements)
    # --n_features only selects between two fixed layouts (>= 128 vs. smaller);
    # it is not used as the actual channel count.
    if args.n_features >= 128:
        hidden_irreps = o3.Irreps("128x0e + 128x1o")
    else:
        hidden_irreps = o3.Irreps("64x0e + 32x1o")
    # NOTE(review): ScaleShiftMACE is understood to apply scale/shift to the
    # per-atom energies before summing over atoms -- confirm this is the
    # intended normalisation for a molecule-level target.
    model = modules.ScaleShiftMACE(
        r_max=args.cutoff,
        num_bessel=8,
        num_polynomial_cutoff=5,
        max_ell=args.max_ell,
        correlation=args.correlation,
        num_interactions=2,
        num_elements=n_elements,
        hidden_irreps=hidden_irreps,
        MLP_irreps=o3.Irreps("16x0e"),
        atomic_energies=atomic_energies,
        avg_num_neighbors=8.0,
        atomic_numbers=z_table.zs,
        gate=torch.nn.functional.silu,
        interaction_cls=RealAgnosticResidualInteractionBlock,
        interaction_cls_first=RealAgnosticInteractionBlock,
        atomic_inter_scale=y_std,
        atomic_inter_shift=y_mean,
    ).to(device)

    train_loader = _make_loader(samples, y, train_idx, z_table,
                                args.cutoff, args.batch, shuffle=True)
    val_loader = _make_loader(samples, y, val_idx, z_table,
                              args.cutoff, args.batch, shuffle=False)
    test_loader = _make_loader(samples, y, test_idx, z_table,
                               args.cutoff, args.batch, shuffle=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=15)
    loss_fn = torch.nn.MSELoss()

    # Stop once validation MSE has not improved for this many epochs in a row.
    early_stop_patience = 25
    best_val, best_state, wait = float("inf"), None, 0
    t0 = time.time()
    for _epoch in range(args.epochs):
        model.train()
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch, training=True, compute_force=False)
            # Targets are stored as float64; cast to the model's output dtype.
            target = batch.energy.to(out["energy"].dtype)
            loss = loss_fn(out["energy"], target)
            loss.backward()
            optimizer.step()

        # Validate: per-sample mean squared error over the validation split.
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            n = 0
            for batch in val_loader:
                batch = batch.to(device)
                out = model(batch, training=False, compute_force=False)
                target = batch.energy.to(out["energy"].dtype)
                val_loss += ((out["energy"] - target) ** 2).sum().item()
                n += target.numel()
            val_loss /= max(n, 1)
        scheduler.step(val_loss)

        if val_loss < best_val:
            best_val = val_loss
            # Snapshot the weights so later epochs cannot overwrite the best.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1
            if wait >= early_stop_patience:
                break
    train_time = time.time() - t0

    if best_state is not None:
        model.load_state_dict(best_state)

    # Test
    model.eval()
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            out = model(batch, training=False, compute_force=False)
            pred_chunks.append(out["energy"].cpu().numpy())
            label_chunks.append(batch.energy.cpu().numpy())
    # Flatten both arrays so a stray trailing dimension cannot silently
    # broadcast (N,) against (N, 1) inside the metric computations.
    preds = np.concatenate(pred_chunks).reshape(-1)
    labels = np.concatenate(label_chunks).reshape(-1)
    r2, rmse, mae = _regression_metrics(labels, preds)

    result = {
        "method": "mace",
        "target": args.target,
        "seed": args.seed,
        "dataset": args.dataset,
        "n_total": n_total,
        "max_ell": args.max_ell,
        "correlation": args.correlation,
        "n_params": sum(p.numel() for p in model.parameters()),
        "train_time_s": train_time,
        "r2": r2, "rmse": rmse, "mae": mae,
        "predictions": preds.tolist(),
        "labels": labels.tolist(),
    }
    with open(out_file, "w") as fp:
        json.dump(result, fp, indent=2)
    print(f"[done] MACE R2={r2:.4f} MAE={mae:.4g} -> {out_file}")
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()