"""Unified ★_G training entry point. Runs ★_G-SVD + Ridge or Neural ★_G on full QM9 (or the polarizability-tensor / dipole-vector targets) and writes a JSON result file. Usage: python train_starg.py --method ridge --target gap --qm9_dir /path/to/qm9 \ --group cyclic --group_param 12 --seed 0 --out_dir results/ python train_starg.py --method neural --target alpha_tensor \ --qm9_dir /path/to/qm9 --group octahedral --seed 0 --out_dir results/ """ from __future__ import annotations import argparse import json import time from pathlib import Path import numpy as np import torch from sklearn.linear_model import Ridge, RidgeCV from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX from data.featurizers import ( cyclic_angular_features, octahedral_features, coulomb_eig_extended_features, stack_batch, ) from data.matlab_angular_features import matlab_angular_features from starg_torch.algebra import GroupAlgebra from starg_torch.features import extract_starg_features from starg_torch.neural import NeuralStarG def _build_group(group_type: str, param) -> GroupAlgebra: if group_type in ("cyclic", "dihedral"): return GroupAlgebra(group_type, int(param)) if group_type == "octahedral": return GroupAlgebra("octahedral") raise ValueError(f"unsupported group type: {group_type}") def _featurize(samples, group_type: str, n_rot: int, featurizer: str = "default", n_feat: int = 14) -> torch.Tensor: """Build the (N, n_feat, |G|) feature tensor for a list of samples. `featurizer` controls which molecule-level summary is used. All options preserve the (n_feat, |G|) shape contract, see CONTRIBUTING.md. - default : group-appropriate angular features (14 rows) [legacy reinvention; underspecified relative to the MATLAB original, kept for backwards-compat on already-saved results] - matlab_angular : faithful port of MATLAB QM9_experiment's angular_features() (cyclic/dihedral only). Numerically equivalent to MATLAB at 5.7e-14 per element (see data/test_matlab_equivalence.py). This is the published-MATLAB-pipeline analog; use this when reproducing or improving on the MATLAB-1k QM9 results. - cm_extended : 14 angular + 29 Coulomb-matrix sorted eigenvalues replicated as invariant rows (cyclic group only) """ if featurizer == "matlab_angular": if group_type not in ("cyclic", "dihedral"): raise ValueError( "matlab_angular featurizer is wired for cyclic/dihedral " "groups (port of MATLAB QM9_experiment.angular_features)." ) return stack_batch( samples, matlab_angular_features, n_rot=n_rot, n_feat_target=n_feat, ) if featurizer == "cm_extended": if group_type not in ("cyclic", "dihedral"): raise ValueError( "cm_extended featurizer is currently wired for cyclic/dihedral " "groups; the 14 angular rows it extends are z-axis specific." ) return stack_batch(samples, coulomb_eig_extended_features, n_rot=n_rot) if group_type in ("cyclic", "dihedral"): return stack_batch(samples, cyclic_angular_features, n_rot=n_rot) if group_type == "octahedral": return stack_batch(samples, octahedral_features) raise ValueError(group_type) def _build_target(samples, target: str) -> np.ndarray: if target in PROPERTY_INDEX: return np.array([s.properties[PROPERTY_INDEX[target]] for s in samples]) if target == "mu_vector": # Compute dipole vector from Mulliken charges out = np.zeros((len(samples), 3)) for i, s in enumerate(samples): pos = s.coords - s.coords.mean(axis=0, keepdims=True) out[i] = (s.charges[:, None] * pos).sum(axis=0) return out if target == "alpha_tensor": # Approximate polarizability tensor proxy: use Mulliken-weighted # second moment as a 6-dim symmetric rank-2 target. The true α # tensor is not provided in QM9 (only the trace α is); we expose # this slot for QM7-X integration. out = np.zeros((len(samples), 6)) for i, s in enumerate(samples): pos = s.coords - s.coords.mean(axis=0, keepdims=True) q = s.charges out[i, 0] = (q * pos[:, 0] ** 2).sum() out[i, 1] = (q * pos[:, 1] ** 2).sum() out[i, 2] = (q * pos[:, 2] ** 2).sum() out[i, 3] = (q * pos[:, 0] * pos[:, 1]).sum() out[i, 4] = (q * pos[:, 0] * pos[:, 2]).sum() out[i, 5] = (q * pos[:, 1] * pos[:, 2]).sum() return out raise ValueError(f"unknown target: {target}") def main(): ap = argparse.ArgumentParser() ap.add_argument("--method", choices=["ridge", "neural"], required=True) ap.add_argument("--target", required=True, help="property name (e.g. gap, mu, alpha) or vector/tensor target") ap.add_argument("--qm9_dir", required=True) ap.add_argument("--group", choices=["cyclic", "dihedral", "octahedral"], default="cyclic") ap.add_argument("--group_param", default=12) ap.add_argument("--n_rot", type=int, default=12) ap.add_argument("--n_feat", type=int, default=14, help="number of feature rows per (atom, rotation). " "MATLAB QM9_experiment defaults to 48; the legacy " "Python default is 14. Use 48 with " "--featurizer matlab_angular to reproduce the " "published MATLAB-1k feature richness.") ap.add_argument("--max_molecules", type=int, default=None, help="if set, subsample QM9 (for debugging). Default: full 134k") ap.add_argument("--seed", type=int, default=0) ap.add_argument("--out_dir", default="results/") ap.add_argument("--epochs", type=int, default=300) ap.add_argument("--lr", type=float, default=0.003) ap.add_argument("--batch", type=int, default=256) ap.add_argument("--device", default="cuda") ap.add_argument("--hidden_widths", default="64 32", help="space-separated widths for Neural-starG hidden layers; " "default '64 32' matches the original manuscript spec. " "Try '256 128' (~440k params) or '512 256 128' (~1.2M, " "MACE-comparable) to test capacity scaling.") ap.add_argument("--featurizer", default="default", choices=("default", "matlab_angular", "cm_extended"), help="molecule-level feature recipe. 'default' uses the " "group-appropriate angular projections (14 rows; " "legacy reinvention). 'matlab_angular' is the " "faithful port of MATLAB QM9_experiment's " "angular_features (numerically equivalent to " "MATLAB at 5.7e-14 per element; cyclic/dihedral). " "'cm_extended' appends 29 Coulomb-matrix eigenvalues " "as invariant rows for an information-richer " "(43, |G|) tensor. See CONTRIBUTING.md for the " "(n_feat, |G|) shape contract.") args = ap.parse_args() hidden_widths = [int(w) for w in args.hidden_widths.split()] out_dir = Path(args.out_dir) / f"starg_{args.method}" / args.target out_dir.mkdir(parents=True, exist_ok=True) out_file = out_dir / f"seed{args.seed}.json" torch.manual_seed(args.seed) np.random.seed(args.seed) device = torch.device(args.device if torch.cuda.is_available() else "cpu") print(f"[load] QM9 from {args.qm9_dir}") ds = QM9Dataset(args.qm9_dir, max_molecules=args.max_molecules) print(f"[load] {len(ds)} molecules") samples = [ds[i] for i in range(len(ds))] train_idx, val_idx, test_idx = qm9_split(len(ds), seed=args.seed) G = _build_group(args.group, args.group_param) G.to(device) print(f"[feat] computing {args.group} features (n_g={G.n})") t0 = time.time() X = _featurize(samples, args.group, args.n_rot, featurizer=args.featurizer, n_feat=args.n_feat) # (N, n_f, n_g) y = _build_target(samples, args.target) feat_time = time.time() - t0 X_tr, X_va, X_te = X[train_idx], X[val_idx], X[test_idx] y_tr, y_va, y_te = y[train_idx], y[val_idx], y[test_idx] if args.method == "ridge": Phi_tr, norm = extract_starg_features(X_tr.to(device), G) Phi_va, _ = extract_starg_features(X_va.to(device), G, norm=norm) Phi_te, _ = extract_starg_features(X_te.to(device), G, norm=norm) Phi_tr = Phi_tr.cpu().numpy() Phi_va = Phi_va.cpu().numpy() Phi_te = Phi_te.cpu().numpy() # RidgeCV over the manuscript's grid; if y is multi-output, fit per column alphas = np.logspace(-3, 3, 7) if y_tr.ndim == 1: model = RidgeCV(alphas=alphas) model.fit(np.vstack([Phi_tr, Phi_va]), np.concatenate([y_tr, y_va])) y_pred = model.predict(Phi_te) else: preds = [] for j in range(y_tr.shape[1]): m = RidgeCV(alphas=alphas) m.fit(np.vstack([Phi_tr, Phi_va]), np.concatenate([y_tr[:, j], y_va[:, j]])) preds.append(m.predict(Phi_te)) y_pred = np.stack(preds, axis=1) n_params = Phi_tr.shape[1] else: # Neural ★_G, hidden widths configurable via --hidden_widths model = NeuralStarG( layer_sizes=[X.shape[1], *hidden_widths], G=G, output_dim=1 if y_tr.ndim == 1 else y_tr.shape[1], ).to(device) opt = torch.optim.Adam(model.parameters(), lr=args.lr) Xt = X_tr.to(device); yt = torch.tensor(y_tr, dtype=torch.float32, device=device) Xv = X_va.to(device); yv = torch.tensor(y_va, dtype=torch.float32, device=device) if yt.ndim == 1: yt = yt.unsqueeze(-1); yv = yv.unsqueeze(-1) best_val, best_state, wait = float("inf"), None, 0 n = Xt.shape[0] for ep in range(args.epochs): perm = torch.randperm(n) for i in range(0, n, args.batch): idx = perm[i : i + args.batch] opt.zero_grad() pred = model(Xt[idx]) loss = ((pred - yt[idx]) ** 2).mean() loss.backward() opt.step() with torch.no_grad(): val_pred = model(Xv) val_loss = ((val_pred - yv) ** 2).mean().item() if val_loss < best_val: best_val = val_loss best_state = {k: v.detach().clone() for k, v in model.state_dict().items()} wait = 0 else: wait += 1 if wait >= 20: break if best_state is not None: model.load_state_dict(best_state) with torch.no_grad(): y_pred = model(X_te.to(device)).cpu().numpy() if y_te.ndim == 1: y_pred = y_pred.flatten() n_params = sum(p.numel() for p in model.parameters()) ss_res = ((y_te - y_pred) ** 2).sum() ss_tot = ((y_te - y_te.mean(axis=0, keepdims=True)) ** 2).sum() r2 = float(1 - ss_res / ss_tot) rmse = float(np.sqrt(((y_te - y_pred) ** 2).mean())) mae = float(np.abs(y_te - y_pred).mean()) result = { "method": f"starg_{args.method}", "target": args.target, "group": args.group, "group_param": args.group_param, "featurizer": args.featurizer, "hidden_widths": args.hidden_widths if args.method == "neural" else None, "seed": args.seed, "n_train": len(train_idx), "n_val": len(val_idx), "n_test": len(test_idx), "n_total": len(ds), "n_params": int(n_params), "feat_time_s": feat_time, "r2": r2, "rmse": rmse, "mae": mae, # Per-test-molecule predictions for isomer_audit.py and # per_irrep_audit.py. ~150KB per JSON; cheap. "predictions": ( y_pred.tolist() if y_pred.ndim == 1 else y_pred.tolist() ), "test_idx": test_idx.tolist(), } with open(out_file, "w") as fp: json.dump(result, fp, indent=2) print(f"[done] R2={r2:.4f} RMSE={rmse:.4g} -> {out_file}") if __name__ == "__main__": main()