"""MLP baselines (Standard / Invariant / Augmented). These match the architectures specified in the manuscript Methods section "Baseline Architectures and Training" exactly: Hidden = [64, 32], ReLU, Adam (β1=0.9, β2=0.999, ε=1e-8), He init, early stopping on validation MSE (patience 20). Three input pipelines: standard : raw frontal slice X(:, e), z-norm invariant : [mean, std, min, max]_g X, z-norm augmented : raw slice; training set replicated |G| times across orbit """ from __future__ import annotations import argparse import json import time from pathlib import Path import numpy as np import torch import torch.nn as nn from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX from data.featurizers import cyclic_angular_features, octahedral_features, stack_batch class MLP(nn.Module): def __init__(self, in_dim: int, hidden=(64, 32), out_dim: int = 1): super().__init__() layers = [] prev = in_dim for h in hidden: layers += [nn.Linear(prev, h), nn.ReLU()] prev = h layers.append(nn.Linear(prev, out_dim)) self.net = nn.Sequential(*layers) # He init for m in self.modules(): if isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight, nonlinearity="relu") nn.init.zeros_(m.bias) def forward(self, x): return self.net(x) def _build_inputs(X: torch.Tensor, mode: str, y: np.ndarray): """Build per-mode (X_train, y_train) representation. X is (N, n_f, n_g).""" N, n_f, n_g = X.shape if mode == "standard": return X[:, :, 0].numpy(), y if mode == "invariant": mean = X.mean(dim=2) std = X.std(dim=2) mn = X.min(dim=2).values mx = X.max(dim=2).values return torch.cat([mean, std, mn, mx], dim=1).numpy(), y if mode == "augmented": # Replicate (N * n_g, n_f) inputs and (N * n_g,) targets Xa = X.permute(0, 2, 1).reshape(N * n_g, n_f).numpy() if y.ndim == 1: ya = np.tile(y, n_g) else: ya = np.tile(y, (n_g, 1)) return Xa, ya raise ValueError(mode) def main(): ap = argparse.ArgumentParser() ap.add_argument("--mode", choices=["standard", "invariant", "augmented"], required=True) ap.add_argument("--target", required=True) ap.add_argument("--qm9_dir", required=True) ap.add_argument("--group", choices=["cyclic", "dihedral", "octahedral"], default="cyclic") ap.add_argument("--n_rot", type=int, default=12) ap.add_argument("--max_molecules", type=int, default=None) ap.add_argument("--seed", type=int, default=0) ap.add_argument("--epochs", type=int, default=300) ap.add_argument("--lr", type=float, default=0.003) ap.add_argument("--batch", type=int, default=256) ap.add_argument("--patience", type=int, default=20) ap.add_argument("--out_dir", default="results/") ap.add_argument("--device", default="cuda") args = ap.parse_args() torch.manual_seed(args.seed) np.random.seed(args.seed) device = torch.device(args.device if torch.cuda.is_available() else "cpu") out_dir = Path(args.out_dir) / f"mlp_{args.mode}" / args.target out_dir.mkdir(parents=True, exist_ok=True) out_file = out_dir / f"seed{args.seed}.json" ds = QM9Dataset(args.qm9_dir, max_molecules=args.max_molecules) samples = [ds[i] for i in range(len(ds))] train_idx, val_idx, test_idx = qm9_split(len(ds), seed=args.seed) if args.group == "octahedral": X = stack_batch(samples, octahedral_features) else: X = stack_batch(samples, cyclic_angular_features, n_rot=args.n_rot) if args.target in PROPERTY_INDEX: y = np.array([s.properties[PROPERTY_INDEX[args.target]] for s in samples]) else: from train_starg import _build_target y = _build_target(samples, args.target) X_tr, X_va, X_te = X[train_idx], X[val_idx], X[test_idx] y_tr, y_va, y_te = y[train_idx], y[val_idx], y[test_idx] X_tr_in, y_tr_in = _build_inputs(X_tr, args.mode, y_tr) # For validation/test always use the un-augmented frontal slice X_va_in, y_va_in = _build_inputs(X_va, "standard" if args.mode == "augmented" else args.mode, y_va) X_te_in, y_te_in = _build_inputs(X_te, "standard" if args.mode == "augmented" else args.mode, y_te) mu = X_tr_in.mean(axis=0); sig = X_tr_in.std(axis=0) + 1e-8 X_tr_in = (X_tr_in - mu) / sig X_va_in = (X_va_in - mu) / sig X_te_in = (X_te_in - mu) / sig out_dim = 1 if y_tr.ndim == 1 else y_tr.shape[1] model = MLP(X_tr_in.shape[1], hidden=(64, 32), out_dim=out_dim).to(device) opt = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8) Xt = torch.tensor(X_tr_in, dtype=torch.float32, device=device) yt = torch.tensor(y_tr_in, dtype=torch.float32, device=device) if yt.ndim == 1: yt = yt.unsqueeze(-1) Xv = torch.tensor(X_va_in, dtype=torch.float32, device=device) yv = torch.tensor(y_va_in, dtype=torch.float32, device=device) if yv.ndim == 1: yv = yv.unsqueeze(-1) Xe = torch.tensor(X_te_in, dtype=torch.float32, device=device) best_val, best_state, wait = float("inf"), None, 0 t0 = time.time() for ep in range(args.epochs): perm = torch.randperm(Xt.shape[0]) for i in range(0, Xt.shape[0], args.batch): idx = perm[i : i + args.batch] opt.zero_grad() pred = model(Xt[idx]) loss = ((pred - yt[idx]) ** 2).mean() loss.backward() opt.step() with torch.no_grad(): val_pred = model(Xv) val_loss = ((val_pred - yv) ** 2).mean().item() if val_loss < best_val: best_val, best_state, wait = val_loss, {k: v.detach().clone() for k, v in model.state_dict().items()}, 0 else: wait += 1 if wait >= args.patience: break train_time = time.time() - t0 if best_state: model.load_state_dict(best_state) with torch.no_grad(): y_pred = model(Xe).cpu().numpy() if y_te_in.ndim == 1: y_pred = y_pred.flatten() ss_res = ((y_te_in - y_pred) ** 2).sum() ss_tot = ((y_te_in - y_te_in.mean(axis=0, keepdims=True)) ** 2).sum() r2 = float(1 - ss_res / ss_tot) rmse = float(np.sqrt(((y_te_in - y_pred) ** 2).mean())) mae = float(np.abs(y_te_in - y_pred).mean()) result = { "method": f"mlp_{args.mode}", "target": args.target, "group": args.group, "seed": args.seed, "n_total": len(ds), "n_params": sum(p.numel() for p in model.parameters()), "train_time_s": train_time, "r2": r2, "rmse": rmse, "mae": mae, "predictions": ( y_pred.tolist() if np.asarray(y_pred).ndim == 1 else y_pred.tolist() ), "test_idx": test_idx.tolist(), } with open(out_file, "w") as fp: json.dump(result, fp, indent=2) print(f"[done] {args.mode} R2={r2:.4f} RMSE={rmse:.4g} -> {out_file}") if __name__ == "__main__": main()