"""MLP baselines (Standard / Invariant / Augmented).
These match the architectures specified in the manuscript Methods section
"Baseline Architectures and Training" exactly:
Hidden = [64, 32], ReLU, Adam (β1=0.9, β2=0.999, ε=1e-8), He init,
early stopping on validation MSE (patience 20).
Three input pipelines:
standard : raw frontal slice X(:, e), z-norm
invariant : [mean, std, min, max]_g X, z-norm
augmented : raw slice; training set replicated |G| times across orbit
"""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX
from data.featurizers import cyclic_angular_features, octahedral_features, stack_batch
class MLP(nn.Module):
    """Fully connected ReLU regressor shared by the three baselines.

    Args:
        in_dim: width of each input row.
        hidden: widths of the hidden layers; every hidden ``nn.Linear`` is
            followed by an ``nn.ReLU``.
        out_dim: number of regression outputs produced by the final layer.

    Every ``nn.Linear`` is given Kaiming-normal (He) weights for ReLU and a
    zero bias after construction.
    """

    def __init__(self, in_dim: int, hidden=(64, 32), out_dim: int = 1):
        super().__init__()
        widths = [in_dim, *hidden]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Output head: linear, no activation.
        stack.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*stack)

        # He initialisation matched to the ReLU non-linearity.
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the raw predictions."""
        return self.net(x)
def _build_inputs(X: torch.Tensor, mode: str, y: np.ndarray):
"""Build per-mode (X_train, y_train) representation. X is (N, n_f, n_g)."""
N, n_f, n_g = X.shape
if mode == "standard":
return X[:, :, 0].numpy(), y
if mode == "invariant":
mean = X.mean(dim=2)
std = X.std(dim=2)
mn = X.min(dim=2).values
mx = X.max(dim=2).values
return torch.cat([mean, std, mn, mx], dim=1).numpy(), y
if mode == "augmented":
# Replicate (N * n_g, n_f) inputs and (N * n_g,) targets
Xa = X.permute(0, 2, 1).reshape(N * n_g, n_f).numpy()
if y.ndim == 1:
ya = np.tile(y, n_g)
else:
ya = np.tile(y, (n_g, 1))
return Xa, ya
raise ValueError(mode)
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--mode", choices=["standard", "invariant", "augmented"], required=True)
ap.add_argument("--target", required=True)
ap.add_argument("--qm9_dir", required=True)
ap.add_argument("--group", choices=["cyclic", "dihedral", "octahedral"], default="cyclic")
ap.add_argument("--n_rot", type=int, default=12)
ap.add_argument("--max_molecules", type=int, default=None)
ap.add_argument("--seed", type=int, default=0)
ap.add_argument("--epochs", type=int, default=300)
ap.add_argument("--lr", type=float, default=0.003)
ap.add_argument("--batch", type=int, default=256)
ap.add_argument("--patience", type=int, default=20)
ap.add_argument("--out_dir", default="results/")
ap.add_argument("--device", default="cuda")
args = ap.parse_args()
torch.manual_seed(args.seed)
np.random.seed(args.seed)
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
out_dir = Path(args.out_dir) / f"mlp_{args.mode}" / args.target
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"seed{args.seed}.json"
ds = QM9Dataset(args.qm9_dir, max_molecules=args.max_molecules)
samples = [ds[i] for i in range(len(ds))]
train_idx, val_idx, test_idx = qm9_split(len(ds), seed=args.seed)
if args.group == "octahedral":
X = stack_batch(samples, octahedral_features)
else:
X = stack_batch(samples, cyclic_angular_features, n_rot=args.n_rot)
if args.target in PROPERTY_INDEX:
y = np.array([s.properties[PROPERTY_INDEX[args.target]] for s in samples])
else:
from train_starg import _build_target
y = _build_target(samples, args.target)
X_tr, X_va, X_te = X[train_idx], X[val_idx], X[test_idx]
y_tr, y_va, y_te = y[train_idx], y[val_idx], y[test_idx]
X_tr_in, y_tr_in = _build_inputs(X_tr, args.mode, y_tr)
# For validation/test always use the un-augmented frontal slice
X_va_in, y_va_in = _build_inputs(X_va, "standard" if args.mode == "augmented" else args.mode, y_va)
X_te_in, y_te_in = _build_inputs(X_te, "standard" if args.mode == "augmented" else args.mode, y_te)
mu = X_tr_in.mean(axis=0); sig = X_tr_in.std(axis=0) + 1e-8
X_tr_in = (X_tr_in - mu) / sig
X_va_in = (X_va_in - mu) / sig
X_te_in = (X_te_in - mu) / sig
out_dim = 1 if y_tr.ndim == 1 else y_tr.shape[1]
model = MLP(X_tr_in.shape[1], hidden=(64, 32), out_dim=out_dim).to(device)
opt = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8)
Xt = torch.tensor(X_tr_in, dtype=torch.float32, device=device)
yt = torch.tensor(y_tr_in, dtype=torch.float32, device=device)
if yt.ndim == 1: yt = yt.unsqueeze(-1)
Xv = torch.tensor(X_va_in, dtype=torch.float32, device=device)
yv = torch.tensor(y_va_in, dtype=torch.float32, device=device)
if yv.ndim == 1: yv = yv.unsqueeze(-1)
Xe = torch.tensor(X_te_in, dtype=torch.float32, device=device)
best_val, best_state, wait = float("inf"), None, 0
t0 = time.time()
for ep in range(args.epochs):
perm = torch.randperm(Xt.shape[0])
for i in range(0, Xt.shape[0], args.batch):
idx = perm[i : i + args.batch]
opt.zero_grad()
pred = model(Xt[idx])
loss = ((pred - yt[idx]) ** 2).mean()
loss.backward()
opt.step()
with torch.no_grad():
val_pred = model(Xv)
val_loss = ((val_pred - yv) ** 2).mean().item()
if val_loss < best_val:
best_val, best_state, wait = val_loss, {k: v.detach().clone() for k, v in model.state_dict().items()}, 0
else:
wait += 1
if wait >= args.patience:
break
train_time = time.time() - t0
if best_state:
model.load_state_dict(best_state)
with torch.no_grad():
y_pred = model(Xe).cpu().numpy()
if y_te_in.ndim == 1:
y_pred = y_pred.flatten()
ss_res = ((y_te_in - y_pred) ** 2).sum()
ss_tot = ((y_te_in - y_te_in.mean(axis=0, keepdims=True)) ** 2).sum()
r2 = float(1 - ss_res / ss_tot)
rmse = float(np.sqrt(((y_te_in - y_pred) ** 2).mean()))
mae = float(np.abs(y_te_in - y_pred).mean())
result = {
"method": f"mlp_{args.mode}",
"target": args.target,
"group": args.group,
"seed": args.seed,
"n_total": len(ds),
"n_params": sum(p.numel() for p in model.parameters()),
"train_time_s": train_time,
"r2": r2, "rmse": rmse, "mae": mae,
"predictions": (
y_pred.tolist() if np.asarray(y_pred).ndim == 1 else y_pred.tolist()
),
"test_idx": test_idx.tolist(),
}
with open(out_file, "w") as fp:
json.dump(result, fp, indent=2)
print(f"[done] {args.mode} R2={r2:.4f} RMSE={rmse:.4g} -> {out_file}")
if __name__ == "__main__":
main()