"""SE(3)-equivariant ENN baseline using e3nn.
A small but representative SE(3)-equivariant message-passing network built
directly on top of e3nn primitives. Predicts scalar (l=0) targets for
energy-like properties and vector / rank-2 tensor targets via the
appropriate output irreps. This baseline is deliberately compact (fewer
parameters than MACE) so the comparison cleanly isolates the algebraic
structure rather than the parameter budget.
Usage:
    python train_baseline_e3nn.py --target gap --qm9_dir /path/to/qm9 \
        --seed 0 --out_dir results/
    python train_baseline_e3nn.py --target mu_vector --output_irreps "1x1o" \
        --qm9_dir /path/to/qm9 --seed 0 --out_dir results/
"""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
try:
from e3nn import o3
from e3nn.nn import Gate
from e3nn.o3 import FullyConnectedTensorProduct, Irreps
E3NN_AVAILABLE = True
except ImportError:
E3NN_AVAILABLE = False
from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX
from data.dataset_adapter import (
load_samples,
build_target,
split_indices,
)
def _build_radial_edges(pos: torch.Tensor, cutoff: float = 5.0):
"""Compute edge index and edge vectors within cutoff."""
N = pos.shape[0]
diff = pos.unsqueeze(0) - pos.unsqueeze(1) # (N, N, 3)
dist = diff.norm(dim=-1)
mask = (dist > 0) & (dist < cutoff)
src, dst = mask.nonzero(as_tuple=True)
edge_vec = diff[src, dst]
edge_len = dist[src, dst]
return src, dst, edge_vec, edge_len
class E3NNLayer(nn.Module):
    """One equivariant message-passing layer: edge tensor product followed by a gate.

    Each edge message is the fully connected tensor product of the source
    node's features with the spherical harmonics of the edge vector. The
    tensor-product weights are produced per edge by a small MLP acting on a
    radial embedding of the edge length. Messages are summed onto the
    destination nodes and passed through an e3nn ``Gate`` non-linearity.

    Args:
        irreps_in: Irreps of the incoming node features.
        irreps_out: Irreps of the node features returned by ``forward``.
        irreps_sh: Irreps of the spherical harmonics evaluated on the edges.
        n_radial: Width of the radial edge embedding fed to the weight MLP.
            Must equal the last dimension of ``edge_len_emb`` in ``forward``.
        radial_hidden: Hidden width of the weight MLP.
    """

    def __init__(
        self,
        irreps_in: Irreps,
        irreps_out: Irreps,
        irreps_sh: Irreps,
        n_radial: int = 16,
        radial_hidden: int = 64,
    ):
        super().__init__()
        # Accept either Irreps instances or their string form.
        irreps_in = Irreps(irreps_in)
        irreps_out = Irreps(irreps_out)
        self.irreps_sh = Irreps(irreps_sh)

        # Gate scheme: l == 0 channels go through an activation directly;
        # every l > 0 channel is multiplied by its own sigmoid-activated
        # scalar, so one extra "0e" gate is requested per gated multiplicity.
        scalars = Irreps([(mul, ir) for mul, ir in irreps_out if ir.l == 0])
        gated = Irreps([(mul, ir) for mul, ir in irreps_out if ir.l > 0])
        gates = Irreps([(mul, "0e") for mul, _ in gated])
        self.gate = Gate(
            scalars, [torch.relu] * len(scalars),
            gates, [torch.sigmoid] * len(gates),
            gated,
        )

        # The tensor product must emit what the gate consumes
        # (scalars + gates + gated). Its weights are supplied per edge.
        self.tp = FullyConnectedTensorProduct(
            irreps_in, self.irreps_sh, self.gate.irreps_in,
            internal_weights=False, shared_weights=False,
        )
        self.weight_net = nn.Sequential(
            nn.Linear(n_radial, radial_hidden),
            nn.SiLU(),
            nn.Linear(radial_hidden, self.tp.weight_numel),
        )

    def forward(self, node_feat: torch.Tensor, edge_index, edge_vec, edge_len_emb):
        """Propagate node features along the edges once.

        Args:
            node_feat: Per-node features laid out according to ``irreps_in``.
            edge_index: Pair ``(src, dst)`` of index tensors, one entry per edge.
            edge_vec: Per-edge vectors on which the spherical harmonics are evaluated.
            edge_len_emb: Per-edge radial embedding of width ``n_radial``.

        Returns:
            Per-node features laid out according to ``irreps_out``.
        """
        src, dst = edge_index
        sh = o3.spherical_harmonics(
            self.irreps_sh, edge_vec, normalize=True, normalization="component",
        )
        weights = self.weight_net(edge_len_emb)
        msg = self.tp(node_feat[src], sh, weights)

        # Sum the incoming messages per destination node.
        out = msg.new_zeros(node_feat.shape[0], msg.shape[1])
        out.index_add_(0, dst, msg)
        return self.gate(out)
class E3NNModel(nn.Module):
    """Compact equivariant message-passing network for a single molecule.

    Atom types are embedded as scalars, refined by ``n_layers`` ``E3NNLayer``
    blocks over a radial-cutoff graph, projected per atom onto
    ``output_irreps`` and averaged over atoms.

    Args:
        n_atom_types: Number of rows of the atom-type embedding table; the
            entries of ``z`` passed to ``forward`` must be smaller than this.
        irreps_hidden: Irreps of the hidden node features.
        irreps_sh: Irreps of the edge spherical harmonics.
        n_layers: Number of message-passing layers.
        output_irreps: Irreps of the molecule-level prediction.
        cutoff: Distance cutoff used both for building edges and for spacing
            the radial basis centres.
    """

    def __init__(
        self,
        n_atom_types: int = 10,
        irreps_hidden: str = "32x0e + 16x1o + 8x2e",
        irreps_sh: str = "1x0e + 1x1o + 1x2e",
        n_layers: int = 3,
        output_irreps: str = "1x0e",
        cutoff: float = 5.0,
    ):
        super().__init__()
        self.cutoff = cutoff
        self.irreps_hidden = Irreps(irreps_hidden)
        self.irreps_sh = Irreps(irreps_sh)
        self.output_irreps = Irreps(output_irreps)

        # The embedding only fills scalar channels, so the first layer sees
        # "<n_scalar>x0e" rather than the full hidden irreps.
        n_scalar = self.irreps_hidden.count("0e")
        self.embedding = nn.Embedding(n_atom_types, n_scalar)

        self.layers = nn.ModuleList()
        irreps_in = Irreps(f"{n_scalar}x0e")
        for _ in range(n_layers):
            self.layers.append(E3NNLayer(irreps_in, self.irreps_hidden, self.irreps_sh))
            irreps_in = self.irreps_hidden

        # Linear read-out expressed as a tensor product with a constant scalar.
        self.head = FullyConnectedTensorProduct(
            self.irreps_hidden, "1x0e", self.output_irreps, internal_weights=True,
        )

        # Fixed Gaussian RBF centres. Stored as a buffer (not a Parameter) so
        # they follow .to(device) and stay in the state_dict under "rbf"
        # without being counted as parameters or handed to the optimizer.
        # 16 centres: the radial embedding width the layers are built with here.
        self.register_buffer("rbf", torch.linspace(0, cutoff, 16))

    def _rbf(self, edge_len: torch.Tensor) -> torch.Tensor:
        """Expand edge lengths in Gaussians centred on ``self.rbf``.

        The Gaussian width equals the spacing between neighbouring centres.
        """
        sigma = self.rbf[1] - self.rbf[0]
        return torch.exp(-((edge_len.unsqueeze(-1) - self.rbf) ** 2) / (2 * sigma ** 2))

    def forward(self, z: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        """Predict the target for one molecule.

        Args:
            z: Integer atom-type indices, one per atom.
            pos: Atom coordinates, one row per atom.

        Returns:
            A 1-D tensor holding the ``output_irreps`` components.
        """
        src, dst, edge_vec, edge_len = _build_radial_edges(pos, cutoff=self.cutoff)
        edge_emb = self._rbf(edge_len)
        h = self.embedding(z)
        for layer in self.layers:
            h = layer(h, (src, dst), edge_vec, edge_emb)
        # Project every atom onto the output irreps, then average over atoms.
        ones = torch.ones(h.shape[0], 1, device=h.device, dtype=h.dtype)
        out = self.head(h, ones)
        return out.mean(dim=0)
def main():
    """Train the e3nn baseline on one target and write a JSON report.

    Parses the command line, loads the samples and targets through the
    dataset adapter, trains with Adam (learning rate halved on validation
    plateaus, early stopping after 20 epochs without improvement), restores
    the best validation checkpoint, evaluates on the test split and writes
    metrics plus raw predictions to
    ``<out_dir>/e3nn/[qm7x/]<target>/seed<seed>.json``.

    Raises:
        RuntimeError: If e3nn could not be imported.
        SystemExit: If the data directory for the chosen dataset is missing.
        ValueError: If a target's component count differs from the model output.
    """
    if not E3NN_AVAILABLE:
        raise RuntimeError("e3nn is not installed. pip install e3nn==0.5.4")

    ap = argparse.ArgumentParser()
    ap.add_argument("--target", required=True)
    ap.add_argument("--dataset", choices=["qm9", "qm7x"], default="qm9")
    ap.add_argument("--qm9_dir",
                    help="QM9 .xyz directory (only when --dataset qm9)")
    ap.add_argument("--qm7x_dir",
                    help="QM7-X HDF5 directory (only when --dataset qm7x)")
    ap.add_argument("--max_molecules", type=int, default=None)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--batch", type=int, default=32)
    ap.add_argument("--output_irreps", default="1x0e",
                    help="e.g. 1x0e for scalar, 1x1o for vector, 1x2e for rank-2")
    ap.add_argument("--out_dir", default="results/")
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()

    if args.dataset == "qm9" and not args.qm9_dir:
        raise SystemExit("--qm9_dir is required when --dataset qm9")
    if args.dataset == "qm7x" and not args.qm7x_dir:
        raise SystemExit("--qm7x_dir is required when --dataset qm7x")
    data_dir = args.qm9_dir if args.dataset == "qm9" else args.qm7x_dir

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Falls back to CPU whenever CUDA is unavailable.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if args.dataset == "qm9":
        out_dir = Path(args.out_dir) / "e3nn" / args.target
    else:
        out_dir = Path(args.out_dir) / "e3nn" / "qm7x" / args.target
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"seed{args.seed}.json"

    samples = load_samples(args.dataset, data_dir,
                           max_molecules=args.max_molecules)
    n_total = len(samples)
    train_idx, val_idx, test_idx = split_indices(args.dataset, n_total,
                                                 seed=args.seed)
    y = build_target(args.dataset, samples, args.target)

    # Normalisation statistics come from the training split only, so no
    # validation/test label information leaks into training.
    y_train = y[np.asarray(train_idx)]
    # A constant offset is invariant under rotation, so it may only be added
    # to even-scalar outputs; vector/tensor targets are rescaled but not
    # shifted, which keeps the de-normalised prediction equivariant.
    scalar_output = all(ir.l == 0 and ir.p == 1 for _, ir in Irreps(args.output_irreps))
    y_mean = float(y_train.mean()) if scalar_output else 0.0
    y_std = float(y_train.std() + 1e-8)
    y_norm = (y - y_mean) / y_std

    # Size the embedding table from the data so atom types beyond the model
    # default of 10 do not index out of range; never shrink below that
    # default, so runs whose types fit in 10 rows are unchanged.
    # NOTE(review): assumes s.Z holds non-negative integer type ids - confirm.
    max_type = max((int(np.max(s.Z)) for s in samples), default=-1)
    n_atom_types = max(10, max_type + 1)
    model = E3NNModel(n_atom_types=n_atom_types,
                      output_irreps=args.output_irreps).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=10)

    def _inputs(k):
        """Return (atom types, centred coordinates) tensors for sample ``k``."""
        s = samples[k]
        z = torch.tensor(s.Z, dtype=torch.long, device=device)
        pos = torch.tensor(s.coords, dtype=torch.float32, device=device)
        return z, pos - pos.mean(dim=0, keepdim=True)

    def step(idx_list):
        """Run one optimizer step on the mean squared error over ``idx_list``."""
        model.train()
        opt.zero_grad()
        loss = 0.0
        for k in idx_list:
            pred = model(*_inputs(k))
            target = torch.tensor(np.atleast_1d(y_norm[k]), dtype=torch.float32,
                                  device=device).reshape(-1)
            # Fail loudly instead of letting broadcasting hide a mismatch
            # between --target and --output_irreps.
            if target.shape != pred.shape:
                raise ValueError(
                    f"target {args.target!r} has {target.numel()} component(s) but "
                    f"output irreps {args.output_irreps!r} yield {pred.numel()}"
                )
            loss = loss + ((pred - target) ** 2).mean()
        loss = loss / len(idx_list)
        loss.backward()
        opt.step()
        return loss.item()

    def evaluate(idx_arr):
        """Return de-normalised predictions and raw labels for ``idx_arr``."""
        model.eval()
        with torch.no_grad():
            preds, labels = [], []
            for k in idx_arr:
                pred = model(*_inputs(k)).cpu().numpy() * y_std + y_mean
                preds.append(pred)
                labels.append(y[k])
        preds = np.array(preds).reshape(len(idx_arr), -1)
        labels = np.array(labels).reshape(len(idx_arr), -1)
        # Scalar targets are reported as flat vectors.
        if preds.shape[1] == 1:
            preds = preds.ravel()
            labels = labels.ravel()
        return preds, labels

    patience = 20  # epochs without validation improvement before stopping
    best_val, best_state, wait = float("inf"), None, 0
    t0 = time.time()
    rng = np.random.default_rng(args.seed)
    for _ in range(args.epochs):
        order = rng.permutation(train_idx)
        for start in range(0, len(order), args.batch):
            step(order[start:start + args.batch])
        val_preds, val_labels = evaluate(val_idx)
        val_loss = float(((val_preds - val_labels) ** 2).mean())
        sched.step(val_loss)
        if val_loss < best_val:
            best_val, wait = val_loss, 0
            # Clone: state_dict() tensors alias the live parameters.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            wait += 1
            if wait >= patience:
                break
    train_time = time.time() - t0

    if best_state is not None:
        model.load_state_dict(best_state)
    preds, labels = evaluate(test_idx)

    residual = labels - preds
    ss_res = float((residual ** 2).sum())
    ss_tot = float(((labels - labels.mean()) ** 2).sum())
    # R^2 is undefined when the test labels have no variance.
    r2 = float(1 - ss_res / ss_tot) if ss_tot > 0 else float("nan")
    rmse = float(np.sqrt((residual ** 2).mean()))
    mae = float(np.abs(residual).mean())

    result = {
        "method": "e3nn",
        "target": args.target,
        "output_irreps": args.output_irreps,
        "seed": args.seed,
        "dataset": args.dataset,
        "n_total": n_total,
        "n_params": sum(p.numel() for p in model.parameters()),
        "train_time_s": train_time,
        "r2": r2, "rmse": rmse, "mae": mae,
        "predictions": np.asarray(preds).reshape(-1).tolist(),
        "labels": np.asarray(labels).reshape(-1).tolist(),
    }
    with open(out_file, "w", encoding="utf-8") as fp:
        json.dump(result, fp, indent=2)
    print(f"[done] e3nn R2={r2:.4f} MAE={mae:.4g} -> {out_file}")
# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()