"""SE(3)-equivariant ENN baseline using e3nn. A small but representative SE(3)-equivariant message-passing network built directly on top of e3nn primitives. Predicts scalar (l=0) targets for energy-like properties and vector / rank-2 tensor targets via the appropriate output irreps. This baseline is deliberately compact (fewer parameters than MACE) so the comparison cleanly isolates the algebraic structure rather than the parameter budget. Usage: python train_baseline_e3nn.py --target gap --qm9_dir /path/to/qm9 \ --seed 0 --out_dir results/ python train_baseline_e3nn.py --target mu_vector --output_irreps "1x1o" \ --qm9_dir /path/to/qm9 --seed 0 --out_dir results/ """ from __future__ import annotations import argparse import json import time from pathlib import Path import numpy as np import torch import torch.nn as nn try: from e3nn import o3 from e3nn.nn import Gate from e3nn.o3 import FullyConnectedTensorProduct, Irreps E3NN_AVAILABLE = True except ImportError: E3NN_AVAILABLE = False from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX from data.dataset_adapter import ( load_samples, build_target, split_indices, ) def _build_radial_edges(pos: torch.Tensor, cutoff: float = 5.0): """Compute edge index and edge vectors within cutoff.""" N = pos.shape[0] diff = pos.unsqueeze(0) - pos.unsqueeze(1) # (N, N, 3) dist = diff.norm(dim=-1) mask = (dist > 0) & (dist < cutoff) src, dst = mask.nonzero(as_tuple=True) edge_vec = diff[src, dst] edge_len = dist[src, dst] return src, dst, edge_vec, edge_len class E3NNLayer(nn.Module): """One SE(3)-equivariant layer: TP with spherical harmonics on edges + gate.""" def __init__(self, irreps_in: Irreps, irreps_out: Irreps, irreps_sh: Irreps): super().__init__() self.irreps_sh = irreps_sh # Gate scheme: split irreps_out into scalars and gated irreps gates = Irreps([(mul, "0e") for mul, (l, p) in irreps_out if l > 0]) scalars = Irreps([(mul, ir) for mul, ir in irreps_out if ir.l == 0]) gated = Irreps([(mul, ir) for mul, ir in irreps_out if ir.l > 0]) self.gate = Gate(scalars, [torch.relu] * len(scalars), gates, [torch.sigmoid] * len(gates), gated) self.tp = FullyConnectedTensorProduct( irreps_in, irreps_sh, self.gate.irreps_in, internal_weights=False, shared_weights=False, ) self.weight_net = nn.Sequential( nn.Linear(16, 64), nn.SiLU(), nn.Linear(64, self.tp.weight_numel), ) def forward(self, node_feat: torch.Tensor, edge_index, edge_vec, edge_len_emb): src, dst = edge_index sh = o3.spherical_harmonics(self.irreps_sh, edge_vec, normalize=True, normalization="component") weights = self.weight_net(edge_len_emb) msg = self.tp(node_feat[src], sh, weights) out = torch.zeros(node_feat.shape[0], msg.shape[1], device=node_feat.device, dtype=msg.dtype) out.index_add_(0, dst, msg) return self.gate(out) class E3NNModel(nn.Module): def __init__( self, n_atom_types: int = 10, irreps_hidden: str = "32x0e + 16x1o + 8x2e", irreps_sh: str = "1x0e + 1x1o + 1x2e", n_layers: int = 3, output_irreps: str = "1x0e", cutoff: float = 5.0, ): super().__init__() self.cutoff = cutoff self.embedding = nn.Embedding(n_atom_types, Irreps(irreps_hidden).count("0e")) self.irreps_hidden = Irreps(irreps_hidden) self.irreps_sh = Irreps(irreps_sh) self.layers = nn.ModuleList() irreps_in = Irreps(f"{self.irreps_hidden.count('0e')}x0e") for _ in range(n_layers): self.layers.append(E3NNLayer(irreps_in, self.irreps_hidden, self.irreps_sh)) irreps_in = self.irreps_hidden self.output_irreps = Irreps(output_irreps) self.head = FullyConnectedTensorProduct(self.irreps_hidden, "1x0e", self.output_irreps, internal_weights=True) self.rbf = nn.Parameter(torch.linspace(0, cutoff, 16), requires_grad=False) def _rbf(self, edge_len: torch.Tensor) -> torch.Tensor: sigma = (self.rbf[1] - self.rbf[0]) return torch.exp(-((edge_len.unsqueeze(-1) - self.rbf) ** 2) / (2 * sigma ** 2)) def forward(self, z: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: src, dst, edge_vec, edge_len = _build_radial_edges(pos, cutoff=self.cutoff) edge_emb = self._rbf(edge_len) h = self.embedding(z) for layer in self.layers: h = layer(h, (src, dst), edge_vec, edge_emb) # Pool: sum / mean over atoms, then project to output irreps ones = torch.ones(h.shape[0], 1, device=h.device, dtype=h.dtype) out = self.head(h, ones) return out.mean(dim=0) def main(): if not E3NN_AVAILABLE: raise RuntimeError("e3nn is not installed. pip install e3nn==0.5.4") ap = argparse.ArgumentParser() ap.add_argument("--target", required=True) ap.add_argument("--dataset", choices=["qm9", "qm7x"], default="qm9") ap.add_argument("--qm9_dir", help="QM9 .xyz directory (only when --dataset qm9)") ap.add_argument("--qm7x_dir", help="QM7-X HDF5 directory (only when --dataset qm7x)") ap.add_argument("--max_molecules", type=int, default=None) ap.add_argument("--seed", type=int, default=0) ap.add_argument("--epochs", type=int, default=200) ap.add_argument("--lr", type=float, default=5e-4) ap.add_argument("--batch", type=int, default=32) ap.add_argument("--output_irreps", default="1x0e", help="e.g. 1x0e for scalar, 1x1o for vector, 1x2e for rank-2") ap.add_argument("--out_dir", default="results/") ap.add_argument("--device", default="cuda") args = ap.parse_args() if args.dataset == "qm9" and not args.qm9_dir: raise SystemExit("--qm9_dir is required when --dataset qm9") if args.dataset == "qm7x" and not args.qm7x_dir: raise SystemExit("--qm7x_dir is required when --dataset qm7x") data_dir = args.qm9_dir if args.dataset == "qm9" else args.qm7x_dir torch.manual_seed(args.seed) np.random.seed(args.seed) device = torch.device(args.device if torch.cuda.is_available() else "cpu") if args.dataset == "qm9": out_dir = Path(args.out_dir) / "e3nn" / args.target else: out_dir = Path(args.out_dir) / "e3nn" / "qm7x" / args.target out_dir.mkdir(parents=True, exist_ok=True) out_file = out_dir / f"seed{args.seed}.json" samples = load_samples(args.dataset, data_dir, max_molecules=args.max_molecules) n_total = len(samples) train_idx, val_idx, test_idx = split_indices(args.dataset, n_total, seed=args.seed) y = build_target(args.dataset, samples, args.target) y_mean, y_std = float(y.mean()), float(y.std() + 1e-8) y_norm = (y - y_mean) / y_std model = E3NNModel(output_irreps=args.output_irreps).to(device) opt = torch.optim.Adam(model.parameters(), lr=args.lr) sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=10) def step(idx_list): model.train() opt.zero_grad() loss = 0.0 for k in idx_list: s = samples[k] z = torch.tensor(s.Z, dtype=torch.long, device=device) pos = torch.tensor(s.coords, dtype=torch.float32, device=device) pos = pos - pos.mean(dim=0, keepdim=True) pred = model(z, pos) target = torch.tensor(np.atleast_1d(y_norm[k]), dtype=torch.float32, device=device) loss = loss + ((pred - target) ** 2).mean() loss = loss / len(idx_list) loss.backward() opt.step() return loss.item() def evaluate(idx_arr): model.eval() with torch.no_grad(): preds, labels = [], [] for k in idx_arr: s = samples[k] z = torch.tensor(s.Z, dtype=torch.long, device=device) pos = torch.tensor(s.coords, dtype=torch.float32, device=device) pos = pos - pos.mean(dim=0, keepdim=True) pred = model(z, pos).cpu().numpy() * y_std + y_mean preds.append(pred); labels.append(y[k]) preds = np.array(preds).reshape(len(idx_arr), -1) labels = np.array(labels).reshape(len(idx_arr), -1) if preds.shape[1] == 1: preds = preds.ravel() labels = labels.ravel() return preds, labels best_val, best_state, wait = float("inf"), None, 0 t0 = time.time() rng = np.random.default_rng(args.seed) for ep in range(args.epochs): order = rng.permutation(train_idx) losses = [] for i in range(0, len(order), args.batch): losses.append(step(order[i : i + args.batch])) val_preds, val_labels = evaluate(val_idx) val_loss = float(((val_preds - val_labels) ** 2).mean()) sched.step(val_loss) if val_loss < best_val: best_val, best_state, wait = val_loss, {k: v.detach().clone() for k, v in model.state_dict().items()}, 0 else: wait += 1 if wait >= 20: break train_time = time.time() - t0 if best_state: model.load_state_dict(best_state) preds, labels = evaluate(test_idx) ss_res = ((labels - preds) ** 2).sum(); ss_tot = ((labels - labels.mean()) ** 2).sum() r2 = float(1 - ss_res / ss_tot) rmse = float(np.sqrt(((labels - preds) ** 2).mean())) mae = float(np.abs(labels - preds).mean()) result = { "method": "e3nn", "target": args.target, "output_irreps": args.output_irreps, "seed": args.seed, "dataset": args.dataset, "n_total": n_total, "n_params": sum(p.numel() for p in model.parameters()), "train_time_s": train_time, "r2": r2, "rmse": rmse, "mae": mae, "predictions": np.asarray(preds).reshape(-1).tolist(), "labels": np.asarray(labels).reshape(-1).tolist(), } with open(out_file, "w") as fp: json.dump(result, fp, indent=2) print(f"[done] e3nn R2={r2:.4f} MAE={mae:.4g} -> {out_file}") if __name__ == "__main__": main()