# tensor-group-sym/python/large_scale/train_baseline_mlp.py
"""MLP baselines (Standard / Invariant / Augmented).

These match the architectures specified in the manuscript Methods section
"Baseline Architectures and Training" exactly:
    Hidden = [64, 32], ReLU, Adam (β1=0.9, β2=0.999, ε=1e-8), He init,
    early stopping on validation MSE (patience 20).

Three input pipelines:
    standard  : raw frontal slice X(:, e), z-norm
    invariant : [mean, std, min, max]_g X, z-norm
    augmented : raw slice; training set replicated |G| times across orbit
"""

from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX
from data.featurizers import cyclic_angular_features, octahedral_features, stack_batch


class MLP(nn.Module):
    """Plain ReLU multilayer perceptron used by all three baselines.

    Topology is ``in_dim -> hidden[0] -> ... -> hidden[-1] -> out_dim`` with a
    ReLU after every hidden ``nn.Linear`` and no activation on the output.
    After construction every ``nn.Linear`` weight is re-drawn with
    Kaiming-normal (He) initialisation for ReLU and every bias is set to zero.
    """

    def __init__(self, in_dim: int, hidden=(64, 32), out_dim: int = 1):
        super().__init__()
        widths = [in_dim, *hidden]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Linear read-out head: regression output, no activation.
        stack.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*stack)

        # He init is applied only after every layer exists, walking modules in
        # registration order, so the RNG stream is consumed in a fixed order.
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Map a batch of feature vectors to predictions of width ``out_dim``."""
        out = self.net(x)
        return out


def _build_inputs(X: torch.Tensor, mode: str, y: np.ndarray):
    """Build per-mode (X_train, y_train) representation. X is (N, n_f, n_g)."""
    N, n_f, n_g = X.shape
    if mode == "standard":
        return X[:, :, 0].numpy(), y
    if mode == "invariant":
        mean = X.mean(dim=2)
        std = X.std(dim=2)
        mn = X.min(dim=2).values
        mx = X.max(dim=2).values
        return torch.cat([mean, std, mn, mx], dim=1).numpy(), y
    if mode == "augmented":
        # Replicate (N * n_g, n_f) inputs and (N * n_g,) targets
        Xa = X.permute(0, 2, 1).reshape(N * n_g, n_f).numpy()
        if y.ndim == 1:
            ya = np.tile(y, n_g)
        else:
            ya = np.tile(y, (n_g, 1))
        return Xa, ya
    raise ValueError(mode)


def _parse_args(argv=None) -> argparse.Namespace:
    """Parse the command line for one baseline run (``argv=None`` reads ``sys.argv``)."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--mode", choices=["standard", "invariant", "augmented"], required=True)
    ap.add_argument("--target", required=True)
    ap.add_argument("--qm9_dir", required=True)
    ap.add_argument("--group", choices=["cyclic", "dihedral", "octahedral"], default="cyclic")
    ap.add_argument("--n_rot", type=int, default=12)
    ap.add_argument("--max_molecules", type=int, default=None)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--epochs", type=int, default=300)
    ap.add_argument("--lr", type=float, default=0.003)
    ap.add_argument("--batch", type=int, default=256)
    ap.add_argument("--patience", type=int, default=20)
    ap.add_argument("--out_dir", default="results/")
    ap.add_argument("--device", default="cuda")
    return ap.parse_args(argv)


def _as_column_tensor(a, device) -> torch.Tensor:
    """Convert *a* to a float32 tensor on *device*, reshaping 1-D input to ``(n, 1)``."""
    t = torch.tensor(a, dtype=torch.float32, device=device)
    return t.unsqueeze(-1) if t.ndim == 1 else t


def _fit_early_stopping(model, opt, Xt, yt, Xv, yv, *, epochs, batch, patience):
    """Train *model* with mini-batch MSE and restore its best-validation weights.

    Each epoch visits the training rows in a fresh ``torch.randperm`` order.
    Training stops after *patience* consecutive epochs without a strict
    improvement in validation MSE, or after *epochs* epochs. If any epoch
    improved, the weights from the lowest-validation-MSE epoch are loaded back
    into *model*.

    Returns:
        The best validation MSE seen (``inf`` if it never became finite).
    """
    best_val, best_state, wait = float("inf"), None, 0
    for _ in range(epochs):
        perm = torch.randperm(Xt.shape[0])
        for i in range(0, Xt.shape[0], batch):
            idx = perm[i : i + batch]
            opt.zero_grad()
            loss = ((model(Xt[idx]) - yt[idx]) ** 2).mean()
            loss.backward()
            opt.step()
        with torch.no_grad():
            val_loss = ((model(Xv) - yv) ** 2).mean().item()
        if val_loss < best_val:
            best_val, wait = val_loss, 0
            # Clone so later optimizer steps do not mutate the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            wait += 1
            if wait >= patience:
                break
    if best_state is not None:
        model.load_state_dict(best_state)
    return best_val


def _regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> tuple[float, float, float]:
    """Return ``(r2, rmse, mae)`` pooled over all samples (and outputs, if 2-D).

    R² is reported as NaN when the targets are constant (zero total sum of
    squares), instead of dividing by zero.
    """
    resid = y_true - y_pred
    ss_res = (resid ** 2).sum()
    ss_tot = ((y_true - y_true.mean(axis=0, keepdims=True)) ** 2).sum()
    r2 = float(1 - ss_res / ss_tot) if ss_tot > 0 else float("nan")
    rmse = float(np.sqrt((resid ** 2).mean()))
    mae = float(np.abs(resid).mean())
    return r2, rmse, mae


def main():
    """Train one MLP baseline on a QM9 target and write test metrics as JSON.

    Output goes to ``<out_dir>/mlp_<mode>/<target>/seed<seed>.json``. It holds
    R²/RMSE/MAE on the test split, the test predictions, the test indices,
    the parameter count and the training wall time.
    """
    args = _parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    out_dir = Path(args.out_dir) / f"mlp_{args.mode}" / args.target
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"seed{args.seed}.json"

    ds = QM9Dataset(args.qm9_dir, max_molecules=args.max_molecules)
    samples = [ds[i] for i in range(len(ds))]
    train_idx, val_idx, test_idx = qm9_split(len(ds), seed=args.seed)

    # NOTE(review): "--group dihedral" is accepted but uses the cyclic
    # featurizer below; confirm this is intended.
    if args.group == "octahedral":
        X = stack_batch(samples, octahedral_features)
    else:
        X = stack_batch(samples, cyclic_angular_features, n_rot=args.n_rot)

    if args.target in PROPERTY_INDEX:
        y = np.array([s.properties[PROPERTY_INDEX[args.target]] for s in samples])
    else:
        from train_starg import _build_target
        y = _build_target(samples, args.target)

    X_tr, X_va, X_te = X[train_idx], X[val_idx], X[test_idx]
    y_tr, y_va, y_te = y[train_idx], y[val_idx], y[test_idx]

    X_tr_in, y_tr_in = _build_inputs(X_tr, args.mode, y_tr)
    # Validation/test always see un-augmented inputs: the augmented model is
    # scored on the frontal slice only.
    eval_mode = "standard" if args.mode == "augmented" else args.mode
    X_va_in, y_va_in = _build_inputs(X_va, eval_mode, y_va)
    X_te_in, y_te_in = _build_inputs(X_te, eval_mode, y_te)

    # z-normalisation with statistics from the (possibly augmented) training
    # inputs only; the epsilon guards constant features.
    mu = X_tr_in.mean(axis=0)
    sig = X_tr_in.std(axis=0) + 1e-8
    X_tr_in = (X_tr_in - mu) / sig
    X_va_in = (X_va_in - mu) / sig
    X_te_in = (X_te_in - mu) / sig

    out_dim = 1 if y_tr.ndim == 1 else y_tr.shape[1]
    model = MLP(X_tr_in.shape[1], hidden=(64, 32), out_dim=out_dim).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8)

    Xt = torch.tensor(X_tr_in, dtype=torch.float32, device=device)
    yt = _as_column_tensor(y_tr_in, device)
    Xv = torch.tensor(X_va_in, dtype=torch.float32, device=device)
    yv = _as_column_tensor(y_va_in, device)
    Xe = torch.tensor(X_te_in, dtype=torch.float32, device=device)

    t0 = time.time()
    _fit_early_stopping(
        model, opt, Xt, yt, Xv, yv,
        epochs=args.epochs, batch=args.batch, patience=args.patience,
    )
    train_time = time.time() - t0

    with torch.no_grad():
        y_pred = model(Xe).cpu().numpy()
    if y_te_in.ndim == 1:
        # Match the 1-D target shape so the metrics below do not broadcast.
        y_pred = y_pred.flatten()

    r2, rmse, mae = _regression_metrics(y_te_in, y_pred)

    result = {
        "method": f"mlp_{args.mode}",
        "target": args.target,
        "group": args.group,
        "seed": args.seed,
        "n_total": len(ds),
        "n_params": sum(p.numel() for p in model.parameters()),
        "train_time_s": train_time,
        "r2": r2, "rmse": rmse, "mae": mae,
        "predictions": y_pred.tolist(),
        "test_idx": test_idx.tolist(),
    }
    with open(out_file, "w") as fp:
        json.dump(result, fp, indent=2)
    print(f"[done] {args.mode} R2={r2:.4f}  RMSE={rmse:.4g}  -> {out_file}")


if __name__ == "__main__":
    main()