"""Unified ★_G training entry point.
Runs ★_G-SVD + Ridge or Neural ★_G on full QM9 (or the polarizability-tensor /
dipole-vector targets) and writes a JSON result file.

Usage:
    python train_starg.py --method ridge --target gap --qm9_dir /path/to/qm9 \
        --group cyclic --group_param 12 --seed 0 --out_dir results/

    python train_starg.py --method neural --target alpha_tensor \
        --qm9_dir /path/to/qm9 --group octahedral --seed 0 --out_dir results/
"""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
import numpy as np
import torch
from sklearn.linear_model import Ridge, RidgeCV
from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX
from data.featurizers import (
cyclic_angular_features,
octahedral_features,
coulomb_eig_extended_features,
stack_batch,
)
from data.matlab_angular_features import matlab_angular_features
from starg_torch.algebra import GroupAlgebra
from starg_torch.features import extract_starg_features
from starg_torch.neural import NeuralStarG
def _build_group(group_type: str, param) -> GroupAlgebra:
if group_type in ("cyclic", "dihedral"):
return GroupAlgebra(group_type, int(param))
if group_type == "octahedral":
return GroupAlgebra("octahedral")
raise ValueError(f"unsupported group type: {group_type}")
def _featurize(samples, group_type: str, n_rot: int,
featurizer: str = "default",
n_feat: int = 14) -> torch.Tensor:
"""Build the (N, n_feat, |G|) feature tensor for a list of samples.
`featurizer` controls which molecule-level summary is used. All
options preserve the (n_feat, |G|) shape contract, see CONTRIBUTING.md.
- default : group-appropriate angular features (14 rows)
[legacy reinvention; underspecified relative
to the MATLAB original, kept for
backwards-compat on already-saved results]
- matlab_angular : faithful port of MATLAB QM9_experiment's
angular_features() (cyclic/dihedral only).
Numerically equivalent to MATLAB at 5.7e-14
per element (see data/test_matlab_equivalence.py).
This is the published-MATLAB-pipeline analog;
use this when reproducing or improving on the
MATLAB-1k QM9 results.
- cm_extended : 14 angular + 29 Coulomb-matrix sorted eigenvalues
replicated as invariant rows (cyclic group only)
"""
if featurizer == "matlab_angular":
if group_type not in ("cyclic", "dihedral"):
raise ValueError(
"matlab_angular featurizer is wired for cyclic/dihedral "
"groups (port of MATLAB QM9_experiment.angular_features)."
)
return stack_batch(
samples, matlab_angular_features,
n_rot=n_rot, n_feat_target=n_feat,
)
if featurizer == "cm_extended":
if group_type not in ("cyclic", "dihedral"):
raise ValueError(
"cm_extended featurizer is currently wired for cyclic/dihedral "
"groups; the 14 angular rows it extends are z-axis specific."
)
return stack_batch(samples, coulomb_eig_extended_features, n_rot=n_rot)
if group_type in ("cyclic", "dihedral"):
return stack_batch(samples, cyclic_angular_features, n_rot=n_rot)
if group_type == "octahedral":
return stack_batch(samples, octahedral_features)
raise ValueError(group_type)
def _build_target(samples, target: str) -> np.ndarray:
    """Assemble the regression target array for ``samples``.

    Args:
        samples: Sequence of molecule records exposing ``properties``,
            ``coords`` (indexed as atoms x 3 columns) and ``charges``
            (indexed as a 1-D per-atom array).
        target: A key of ``PROPERTY_INDEX``, ``"mu_vector"`` or
            ``"alpha_tensor"``.

    Returns:
        ``(N,)`` array for a ``PROPERTY_INDEX`` target, ``(N, 3)`` for
        ``"mu_vector"``, ``(N, 6)`` for ``"alpha_tensor"``.

    Raises:
        ValueError: If ``target`` is none of the above.
    """
    if target in PROPERTY_INDEX:
        column = PROPERTY_INDEX[target]
        return np.array([mol.properties[column] for mol in samples])

    if target == "mu_vector":
        # Dipole vector from Mulliken charges, about the geometric centroid.
        dipoles = np.zeros((len(samples), 3))
        for row, mol in enumerate(samples):
            centered = mol.coords - mol.coords.mean(axis=0, keepdims=True)
            dipoles[row] = (mol.charges[:, None] * centered).sum(axis=0)
        return dipoles

    if target == "alpha_tensor":
        # Polarizability-tensor proxy: the Mulliken-weighted second moment,
        # stored as the 6 independent entries of a symmetric rank-2 tensor
        # in the order xx, yy, zz, xy, xz, yz. QM9 only provides the trace
        # of the true tensor; this slot is exposed for QM7-X integration.
        moments = np.zeros((len(samples), 6))
        for row, mol in enumerate(samples):
            centered = mol.coords - mol.coords.mean(axis=0, keepdims=True)
            q = mol.charges
            x, y, z = centered[:, 0], centered[:, 1], centered[:, 2]
            components = (
                q * x ** 2,
                q * y ** 2,
                q * z ** 2,
                q * x * y,
                q * x * z,
                q * y * z,
            )
            for col, component in enumerate(components):
                moments[row, col] = component.sum()
        return moments

    raise ValueError(f"unknown target: {target}")
def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--method", choices=["ridge", "neural"], required=True)
    ap.add_argument("--target", required=True,
                    help="property name (e.g. gap, mu, alpha) or vector/tensor target")
    ap.add_argument("--qm9_dir", required=True)
    ap.add_argument("--group", choices=["cyclic", "dihedral", "octahedral"], default="cyclic")
    # type=int keeps the value (and the JSON record of it) an int whether it
    # comes from the default or from the command line.
    ap.add_argument("--group_param", type=int, default=12,
                    help="integer parameter passed to GroupAlgebra for "
                         "cyclic/dihedral groups; ignored for octahedral")
    ap.add_argument("--n_rot", type=int, default=12)
    ap.add_argument("--n_feat", type=int, default=14,
                    help="number of feature rows per (atom, rotation). "
                         "MATLAB QM9_experiment defaults to 48; the legacy "
                         "Python default is 14. Use 48 with "
                         "--featurizer matlab_angular to reproduce the "
                         "published MATLAB-1k feature richness.")
    ap.add_argument("--max_molecules", type=int, default=None,
                    help="if set, subsample QM9 (for debugging). Default: full 134k")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--out_dir", default="results/")
    ap.add_argument("--epochs", type=int, default=300)
    ap.add_argument("--lr", type=float, default=0.003)
    ap.add_argument("--batch", type=int, default=256)
    ap.add_argument("--patience", type=int, default=20,
                    help="early-stopping patience for --method neural: stop "
                         "after this many consecutive epochs without a new "
                         "best validation loss")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--hidden_widths", default="64 32",
                    help="space-separated widths for Neural-starG hidden layers; "
                         "default '64 32' matches the original manuscript spec. "
                         "Try '256 128' (~440k params) or '512 256 128' (~1.2M, "
                         "MACE-comparable) to test capacity scaling.")
    ap.add_argument("--featurizer", default="default",
                    choices=("default", "matlab_angular", "cm_extended"),
                    help="molecule-level feature recipe. 'default' uses the "
                         "group-appropriate angular projections (14 rows; "
                         "legacy reinvention). 'matlab_angular' is the "
                         "faithful port of MATLAB QM9_experiment's "
                         "angular_features (numerically equivalent to "
                         "MATLAB at 5.7e-14 per element; cyclic/dihedral). "
                         "'cm_extended' appends 29 Coulomb-matrix eigenvalues "
                         "as invariant rows for an information-richer "
                         "(43, |G|) tensor. See CONTRIBUTING.md for the "
                         "(n_feat, |G|) shape contract.")
    return ap.parse_args()


def _fit_ridge(X_tr, X_va, X_te, y_tr, y_va, G, device):
    """Fit RidgeCV on ★_G features of train+val and predict the test split.

    The feature normalisation returned for the training split is reused for
    the validation and test splits. Multi-output targets are fit one column
    at a time.

    Returns:
        ``(y_pred, n_params)`` where ``n_params`` is the feature dimension.
    """
    Phi_tr, norm = extract_starg_features(X_tr.to(device), G)
    Phi_va, _ = extract_starg_features(X_va.to(device), G, norm=norm)
    Phi_te, _ = extract_starg_features(X_te.to(device), G, norm=norm)
    Phi_tr = Phi_tr.cpu().numpy()
    Phi_va = Phi_va.cpu().numpy()
    Phi_te = Phi_te.cpu().numpy()

    # RidgeCV over the manuscript's grid; if y is multi-output, fit per column
    alphas = np.logspace(-3, 3, 7)
    # Built once: the design matrix is the same for every output column.
    Phi_fit = np.vstack([Phi_tr, Phi_va])
    if y_tr.ndim == 1:
        model = RidgeCV(alphas=alphas)
        model.fit(Phi_fit, np.concatenate([y_tr, y_va]))
        y_pred = model.predict(Phi_te)
    else:
        preds = []
        for j in range(y_tr.shape[1]):
            m = RidgeCV(alphas=alphas)
            m.fit(Phi_fit, np.concatenate([y_tr[:, j], y_va[:, j]]))
            preds.append(m.predict(Phi_te))
        y_pred = np.stack(preds, axis=1)
    return y_pred, Phi_tr.shape[1]


def _fit_neural(X_tr, X_va, X_te, y_tr, y_va, G, device, *,
                in_rows, hidden_widths, lr, epochs, batch, patience):
    """Train Neural ★_G with Adam + MSE and early stopping; predict test.

    The weights with the lowest validation MSE are restored before the test
    prediction. Training stops once ``patience`` consecutive epochs fail to
    improve on the best validation loss.

    Returns:
        ``(y_pred, n_params)`` where ``n_params`` counts model parameters.
    """
    scalar_target = y_tr.ndim == 1
    model = NeuralStarG(
        layer_sizes=[in_rows, *hidden_widths],
        G=G,
        output_dim=1 if scalar_target else y_tr.shape[1],
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    Xt = X_tr.to(device)
    yt = torch.tensor(y_tr, dtype=torch.float32, device=device)
    Xv = X_va.to(device)
    yv = torch.tensor(y_va, dtype=torch.float32, device=device)
    if yt.ndim == 1:
        # Match the (batch, 1) shape of the model output for scalar targets.
        yt = yt.unsqueeze(-1)
        yv = yv.unsqueeze(-1)

    best_val, best_state, wait = float("inf"), None, 0
    n = Xt.shape[0]
    for _ in range(epochs):
        model.train()
        # Permutation is drawn from the default (seeded) generator.
        perm = torch.randperm(n)
        for start in range(0, n, batch):
            idx = perm[start:start + batch]
            opt.zero_grad()
            pred = model(Xt[idx])
            loss = ((pred - yt[idx]) ** 2).mean()
            loss.backward()
            opt.step()

        # Evaluate in eval mode so any train-only layers do not perturb the
        # early-stopping signal.
        model.eval()
        with torch.no_grad():
            val_loss = ((model(Xv) - yv) ** 2).mean().item()
        if val_loss < best_val:
            best_val = val_loss
            # Clone: state_dict() tensors alias the live parameters.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        y_pred = model(X_te.to(device)).cpu().numpy()
    if scalar_target:
        y_pred = y_pred.flatten()
    n_params = sum(p.numel() for p in model.parameters())
    return y_pred, n_params


def main():
    """Run one ★_G training job end to end and write its JSON result.

    Steps: parse CLI arguments, load QM9, featurize, build the target,
    fit either ★_G-SVD + RidgeCV or Neural ★_G, compute test-set
    R2 / RMSE / MAE and dump everything (including per-molecule test
    predictions) to ``<out_dir>/starg_<method>/<target>/seed<seed>.json``.
    """
    args = _parse_args()
    hidden_widths = [int(w) for w in args.hidden_widths.split()]

    out_dir = Path(args.out_dir) / f"starg_{args.method}" / args.target
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"seed{args.seed}.json"

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Falls back to CPU whenever CUDA is unavailable, whatever --device says.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(f"[load] QM9 from {args.qm9_dir}")
    ds = QM9Dataset(args.qm9_dir, max_molecules=args.max_molecules)
    print(f"[load] {len(ds)} molecules")
    # Index-based on purpose: only __len__/__getitem__ are relied on here.
    samples = [ds[i] for i in range(len(ds))]
    train_idx, val_idx, test_idx = qm9_split(len(ds), seed=args.seed)

    G = _build_group(args.group, args.group_param)
    # NOTE(review): assumes GroupAlgebra.to() moves the object in place;
    # the return value is discarded -- confirm against starg_torch.algebra.
    G.to(device)

    print(f"[feat] computing {args.group} features (n_g={G.n})")
    t0 = time.time()
    X = _featurize(samples, args.group, args.n_rot,
                   featurizer=args.featurizer, n_feat=args.n_feat)  # (N, n_f, n_g)
    y = _build_target(samples, args.target)
    feat_time = time.time() - t0

    X_tr, X_va, X_te = X[train_idx], X[val_idx], X[test_idx]
    y_tr, y_va, y_te = y[train_idx], y[val_idx], y[test_idx]

    if args.method == "ridge":
        y_pred, n_params = _fit_ridge(X_tr, X_va, X_te, y_tr, y_va, G, device)
    else:
        # Neural ★_G, hidden widths configurable via --hidden_widths
        y_pred, n_params = _fit_neural(
            X_tr, X_va, X_te, y_tr, y_va, G, device,
            in_rows=X.shape[1],
            hidden_widths=hidden_widths,
            lr=args.lr,
            epochs=args.epochs,
            batch=args.batch,
            patience=args.patience,
        )

    # For multi-output targets these are pooled over all output columns.
    ss_res = ((y_te - y_pred) ** 2).sum()
    ss_tot = ((y_te - y_te.mean(axis=0, keepdims=True)) ** 2).sum()
    r2 = float(1 - ss_res / ss_tot)
    rmse = float(np.sqrt(((y_te - y_pred) ** 2).mean()))
    mae = float(np.abs(y_te - y_pred).mean())

    result = {
        "method": f"starg_{args.method}",
        "target": args.target,
        "group": args.group,
        "group_param": args.group_param,
        "featurizer": args.featurizer,
        "hidden_widths": args.hidden_widths if args.method == "neural" else None,
        "seed": args.seed,
        "n_train": len(train_idx),
        "n_val": len(val_idx),
        "n_test": len(test_idx),
        "n_total": len(ds),
        "n_params": int(n_params),
        "feat_time_s": feat_time,
        "r2": r2,
        "rmse": rmse,
        "mae": mae,
        # Per-test-molecule predictions for isomer_audit.py and
        # per_irrep_audit.py. ~150KB per JSON; cheap.
        "predictions": y_pred.tolist(),
        "test_idx": test_idx.tolist(),
    }
    with open(out_file, "w", encoding="utf-8") as fp:
        json.dump(result, fp, indent=2)
    print(f"[done] R2={r2:.4f} RMSE={rmse:.4g} -> {out_file}")
# Run the training job only when executed as a script, not on import.
if __name__ == "__main__":
    main()