# tensor-group-sym / python / large_scale / train_baseline_e3nn.py
"""SE(3)-equivariant ENN baseline using e3nn.

A small but representative SE(3)-equivariant message-passing network built
directly on top of e3nn primitives. Predicts scalar (l=0) targets for
energy-like properties and vector / rank-2 tensor targets via the
appropriate output irreps. This baseline is deliberately compact (fewer
parameters than MACE) so the comparison cleanly isolates the algebraic
structure rather than the parameter budget.

Usage:
    python train_baseline_e3nn.py --target gap --qm9_dir /path/to/qm9 \
        --seed 0 --out_dir results/

    python train_baseline_e3nn.py --target mu_vector --output_irreps "1x1o" \
        --qm9_dir /path/to/qm9 --seed 0 --out_dir results/
"""

from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

try:
    from e3nn import o3
    from e3nn.nn import Gate
    from e3nn.o3 import FullyConnectedTensorProduct, Irreps
    E3NN_AVAILABLE = True
except ImportError:
    E3NN_AVAILABLE = False

from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX
from data.dataset_adapter import (
    load_samples,
    build_target,
    split_indices,
)


def _build_radial_edges(pos: torch.Tensor, cutoff: float = 5.0):
    """Compute edge index and edge vectors within cutoff."""
    N = pos.shape[0]
    diff = pos.unsqueeze(0) - pos.unsqueeze(1)  # (N, N, 3)
    dist = diff.norm(dim=-1)
    mask = (dist > 0) & (dist < cutoff)
    src, dst = mask.nonzero(as_tuple=True)
    edge_vec = diff[src, dst]
    edge_len = dist[src, dst]
    return src, dst, edge_vec, edge_len


class E3NNLayer(nn.Module):
    """One equivariant message-passing layer: edge tensor product + gated nonlinearity.

    For every edge ``src -> dst`` the source node features are combined with
    the spherical harmonics of the edge vector by a tensor product whose
    weights are produced per edge by a small MLP on the radial edge embedding.
    Messages are summed into ``dst`` and passed through an e3nn ``Gate``.

    Args:
        irreps_in: irreps of the incoming node features.
        irreps_out: requested output irreps. Its ``l = 0`` entries go through
            a ReLU; each ``l > 0`` entry is multiplied by its own
            sigmoid-activated ``0e`` gate. The gate emits the scalars first,
            so the layout actually produced is ``self.gate.irreps_out``, which
            only equals ``irreps_out`` when the latter lists scalars first.
        irreps_sh: spherical-harmonic irreps used to encode edge directions.
        n_rbf: width of the radial edge embedding passed to :meth:`forward`.
            The default of 16 matches the RBF count used by ``E3NNModel``.
        radial_hidden: hidden width of the radial weight MLP.
    """

    def __init__(
        self,
        irreps_in: Irreps,
        irreps_out: Irreps,
        irreps_sh: Irreps,
        n_rbf: int = 16,
        radial_hidden: int = 64,
    ):
        super().__init__()
        self.irreps_sh = irreps_sh
        # Gate scheme: scalars get a plain activation; every l > 0 block gets
        # one 0e gate channel per multiplicity.
        scalars = Irreps([(mul, ir) for mul, ir in irreps_out if ir.l == 0])
        gated = Irreps([(mul, ir) for mul, ir in irreps_out if ir.l > 0])
        gates = Irreps([(mul, "0e") for mul, _ in gated])
        self.gate = Gate(
            scalars, [torch.relu] * len(scalars),
            gates, [torch.sigmoid] * len(gates),
            gated,
        )
        # Per-edge weights (shared_weights=False) come from weight_net below.
        self.tp = FullyConnectedTensorProduct(
            irreps_in, irreps_sh, self.gate.irreps_in, internal_weights=False, shared_weights=False,
        )
        self.weight_net = nn.Sequential(
            nn.Linear(n_rbf, radial_hidden), nn.SiLU(), nn.Linear(radial_hidden, self.tp.weight_numel),
        )

    def forward(self, node_feat: torch.Tensor, edge_index, edge_vec, edge_len_emb):
        """Aggregate equivariant messages and apply the gate.

        Args:
            node_feat: ``(N, irreps_in.dim)`` node features.
            edge_index: ``(src, dst)`` pair of 1-D index tensors.
            edge_vec: ``(E, 3)`` edge vectors (normalized to unit length for
                the spherical harmonics).
            edge_len_emb: ``(E, n_rbf)`` radial embedding of the edge lengths.

        Returns:
            ``(N, self.gate.irreps_out.dim)`` updated node features. Nodes that
            receive no edges get the gate applied to a zero message.
        """
        src, dst = edge_index
        sh = o3.spherical_harmonics(self.irreps_sh, edge_vec, normalize=True, normalization="component")
        weights = self.weight_net(edge_len_emb)
        msg = self.tp(node_feat[src], sh, weights)
        # Unnormalized sum over incoming edges.
        out = torch.zeros(node_feat.shape[0], msg.shape[1], device=node_feat.device, dtype=msg.dtype)
        out.index_add_(0, dst, msg)
        return self.gate(out)


class E3NNModel(nn.Module):
    """Compact equivariant message-passing network built on e3nn primitives.

    Atom types are embedded as ``0e`` scalar features. They are refined by
    ``n_layers`` ``E3NNLayer`` blocks over a radius graph, projected per atom
    to ``output_irreps``, and mean-pooled over atoms.

    Args:
        n_atom_types: rows of the atom-type embedding table. Every value of
            ``z`` passed to :meth:`forward` must lie in ``[0, n_atom_types)``.
        irreps_hidden: hidden node-feature irreps. Its ``0e`` multiplicity
            sets the embedding width and must be positive.
        irreps_sh: spherical-harmonic irreps used for edge directions.
        n_layers: number of message-passing layers.
        output_irreps: irreps of the per-molecule prediction.
        cutoff: radius-graph cutoff, in the same length units as ``pos``.
    """

    def __init__(
        self,
        n_atom_types: int = 10,
        irreps_hidden: str = "32x0e + 16x1o + 8x2e",
        irreps_sh: str = "1x0e + 1x1o + 1x2e",
        n_layers: int = 3,
        output_irreps: str = "1x0e",
        cutoff: float = 5.0,
    ):
        super().__init__()
        self.cutoff = cutoff
        self.irreps_hidden = Irreps(irreps_hidden)
        self.irreps_sh = Irreps(irreps_sh)
        self.output_irreps = Irreps(output_irreps)
        n_scalars = self.irreps_hidden.count("0e")
        if n_scalars == 0:
            raise ValueError(f"irreps_hidden must contain at least one 0e channel, got {irreps_hidden!r}")
        self.embedding = nn.Embedding(n_atom_types, n_scalars)
        self.layers = nn.ModuleList()
        irreps_in = Irreps(f"{n_scalars}x0e")
        for _ in range(n_layers):
            layer = E3NNLayer(irreps_in, self.irreps_hidden, self.irreps_sh)
            self.layers.append(layer)
            # The gate emits scalars first, then the gated irreps, which need
            # not match the ordering of irreps_hidden. Chain on what the layer
            # really outputs so the next tensor product (and the head) read
            # the features with the correct layout.
            irreps_in = layer.gate.irreps_out
        self.head = FullyConnectedTensorProduct(irreps_in, "1x0e", self.output_irreps, internal_weights=True)
        # Fixed Gaussian RBF centres. A buffer (moved by .to(), saved in the
        # state_dict under the same "rbf" key) rather than a frozen Parameter,
        # so the centres are neither counted in n_params nor given to the optimizer.
        self.register_buffer("rbf", torch.linspace(0, cutoff, 16))

    def _rbf(self, edge_len: torch.Tensor) -> torch.Tensor:
        """Expand edge lengths ``(E,)`` into Gaussian features ``(E, 16)``.

        The width equals the spacing between adjacent centres.
        """
        sigma = (self.rbf[1] - self.rbf[0])
        return torch.exp(-((edge_len.unsqueeze(-1) - self.rbf) ** 2) / (2 * sigma ** 2))

    def forward(self, z: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        """Predict the output irreps for a single molecule.

        Args:
            z: ``(N,)`` integer atom-type indices into the embedding table.
            pos: ``(N, 3)`` atomic coordinates.

        Returns:
            ``(output_irreps.dim,)`` prediction: the mean over atoms of the
            per-atom head outputs.
        """
        src, dst, edge_vec, edge_len = _build_radial_edges(pos, cutoff=self.cutoff)
        edge_emb = self._rbf(edge_len)
        h = self.embedding(z)
        for layer in self.layers:
            h = layer(h, (src, dst), edge_vec, edge_emb)
        # Project each atom to the output irreps (tensor product with a
        # constant 0e input acts as an equivariant linear map), then mean-pool.
        ones = torch.ones(h.shape[0], 1, device=h.device, dtype=h.dtype)
        out = self.head(h, ones)
        return out.mean(dim=0)


def main():
    """Train and evaluate the e3nn baseline on one target and write a JSON result.

    Parses CLI arguments and loads the dataset through ``data.dataset_adapter``.
    Trains ``E3NNModel`` with Adam + ReduceLROnPlateau and early stopping on
    the validation MSE. Restores the best-validation weights, evaluates on the
    test split, and writes
    ``<out_dir>/e3nn[/qm7x]/<target>/seed<seed>.json``.

    Raises:
        RuntimeError: if e3nn is not installed.
        SystemExit: on missing/inconsistent CLI arguments, an empty dataset or
            split, or a target whose width does not match ``--output_irreps``.
    """
    if not E3NN_AVAILABLE:
        raise RuntimeError("e3nn is not installed. pip install e3nn==0.5.4")

    ap = argparse.ArgumentParser()
    ap.add_argument("--target", required=True)
    ap.add_argument("--dataset", choices=["qm9", "qm7x"], default="qm9")
    ap.add_argument("--qm9_dir",
                    help="QM9 .xyz directory (only when --dataset qm9)")
    ap.add_argument("--qm7x_dir",
                    help="QM7-X HDF5 directory (only when --dataset qm7x)")
    ap.add_argument("--max_molecules", type=int, default=None)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--batch", type=int, default=32)
    ap.add_argument("--output_irreps", default="1x0e",
                    help="e.g. 1x0e for scalar, 1x1o for vector, 1x2e for rank-2")
    ap.add_argument("--patience", type=int, default=20,
                    help="early-stopping patience in epochs on validation MSE")
    ap.add_argument("--out_dir", default="results/")
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()

    if args.dataset == "qm9" and not args.qm9_dir:
        raise SystemExit("--qm9_dir is required when --dataset qm9")
    if args.dataset == "qm7x" and not args.qm7x_dir:
        raise SystemExit("--qm7x_dir is required when --dataset qm7x")
    data_dir = args.qm9_dir if args.dataset == "qm9" else args.qm7x_dir
    output_irreps = Irreps(args.output_irreps)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    if args.dataset == "qm9":
        out_dir = Path(args.out_dir) / "e3nn" / args.target
    else:
        out_dir = Path(args.out_dir) / "e3nn" / "qm7x" / args.target
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"seed{args.seed}.json"

    samples = load_samples(args.dataset, data_dir,
                           max_molecules=args.max_molecules)
    n_total = len(samples)
    if n_total == 0:
        raise SystemExit(f"no samples loaded from {data_dir}")
    train_idx, val_idx, test_idx = split_indices(args.dataset, n_total,
                                                 seed=args.seed)
    if len(train_idx) == 0 or len(val_idx) == 0 or len(test_idx) == 0:
        raise SystemExit(
            f"empty split (train={len(train_idx)}, val={len(val_idx)}, "
            f"test={len(test_idx)}); check --max_molecules"
        )

    y = np.asarray(build_target(args.dataset, samples, args.target))
    # The per-sample target width must match the model output. Otherwise the
    # loss silently broadcasts, e.g. a 1x0e prediction against a 3-vector.
    target_dim = int(np.prod(y.shape[1:], dtype=np.int64))
    if target_dim != output_irreps.dim:
        raise SystemExit(
            f"target {args.target!r} has {target_dim} component(s) per sample "
            f"but --output_irreps {args.output_irreps!r} has dimension {output_irreps.dim}"
        )

    # Subtracting one scalar mean from every component is only equivariant
    # for 0e outputs. For vector/tensor (or pseudoscalar) outputs the shifted
    # target is something the equivariant model cannot represent, so only
    # rescale in that case.
    only_even_scalars = all(ir.l == 0 and ir.p == 1 for _, ir in output_irreps)
    y_mean = float(y.mean()) if only_even_scalars else 0.0
    y_std = float(y.std() + 1e-8)
    y_norm = (y - y_mean) / y_std

    # nn.Embedding is indexed directly by s.Z, so a fixed 10-row table raises
    # IndexError for any Z >= 10. Size it from the data, keeping at least 10
    # rows so runs whose Z values all fit are unchanged (same n_params).
    n_atom_types = max(10, 1 + max(int(np.max(s.Z)) for s in samples))
    model = E3NNModel(n_atom_types=n_atom_types, output_irreps=args.output_irreps).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=10)

    def featurize(k):
        """Return ``(z, pos)`` tensors for sample ``k``, positions centred on their mean."""
        s = samples[k]
        z = torch.tensor(s.Z, dtype=torch.long, device=device)
        pos = torch.tensor(s.coords, dtype=torch.float32, device=device)
        return z, pos - pos.mean(dim=0, keepdim=True)

    def step(idx_list):
        """Take one optimizer step on the mean per-molecule normalized MSE."""
        model.train()
        opt.zero_grad()
        loss = 0.0
        for k in idx_list:
            pred = model(*featurize(k))
            target = torch.tensor(np.atleast_1d(y_norm[k]), dtype=torch.float32, device=device)
            loss = loss + ((pred - target) ** 2).mean()
        loss = loss / len(idx_list)
        loss.backward()
        opt.step()
        return loss.item()

    def evaluate(idx_arr):
        """Return de-normalized predictions and raw labels for ``idx_arr``.

        Both are 1-D for single-component targets, else ``(n, dim)``.
        """
        model.eval()
        with torch.no_grad():
            preds, labels = [], []
            for k in idx_arr:
                pred = model(*featurize(k)).cpu().numpy() * y_std + y_mean
                preds.append(pred)
                labels.append(y[k])
        preds = np.array(preds).reshape(len(idx_arr), -1)
        labels = np.array(labels).reshape(len(idx_arr), -1)
        if preds.shape[1] == 1:
            preds = preds.ravel()
            labels = labels.ravel()
        return preds, labels

    best_val, best_state, wait = float("inf"), None, 0
    t0 = time.time()
    rng = np.random.default_rng(args.seed)
    for _ in range(args.epochs):
        order = rng.permutation(train_idx)
        for i in range(0, len(order), args.batch):
            step(order[i : i + args.batch])
        val_preds, val_labels = evaluate(val_idx)
        val_loss = float(((val_preds - val_labels) ** 2).mean())
        sched.step(val_loss)
        if val_loss < best_val:
            best_val, wait = val_loss, 0
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            wait += 1
            if wait >= args.patience:
                break
    train_time = time.time() - t0
    if best_state is not None:
        model.load_state_dict(best_state)
    preds, labels = evaluate(test_idx)
    ss_res = ((labels - preds) ** 2).sum()
    ss_tot = ((labels - labels.mean()) ** 2).sum()
    r2 = float(1 - ss_res / ss_tot)
    rmse = float(np.sqrt(((labels - preds) ** 2).mean()))
    mae = float(np.abs(labels - preds).mean())

    result = {
        "method": "e3nn",
        "target": args.target,
        "output_irreps": args.output_irreps,
        "seed": args.seed,
        "dataset": args.dataset,
        "n_total": n_total,
        "n_params": sum(p.numel() for p in model.parameters()),
        "train_time_s": train_time,
        "r2": r2, "rmse": rmse, "mae": mae,
        "predictions": np.asarray(preds).reshape(-1).tolist(),
        "labels": np.asarray(labels).reshape(-1).tolist(),
    }
    with open(out_file, "w") as fp:
        json.dump(result, fp, indent=2)
    print(f"[done] e3nn R2={r2:.4f}  MAE={mae:.4g}  -> {out_file}")


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()