# tensor-group-sym / python / large_scale / train_starg.py
"""Unified ★_G training entry point.

Runs ★_G-SVD + Ridge or Neural ★_G on full QM9 (or the polarizability-tensor /
dipole-vector targets) and writes a JSON result file.

Usage:
    python train_starg.py --method ridge --target gap --qm9_dir /path/to/qm9 \
        --group cyclic --group_param 12 --seed 0 --out_dir results/

    python train_starg.py --method neural --target alpha_tensor \
        --qm9_dir /path/to/qm9 --group octahedral --seed 0 --out_dir results/
"""

from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import numpy as np
import torch
from sklearn.linear_model import Ridge, RidgeCV

from data.qm9 import QM9Dataset, qm9_split, PROPERTY_INDEX
from data.featurizers import (
    cyclic_angular_features,
    octahedral_features,
    coulomb_eig_extended_features,
    stack_batch,
)
from data.matlab_angular_features import matlab_angular_features
from starg_torch.algebra import GroupAlgebra
from starg_torch.features import extract_starg_features
from starg_torch.neural import NeuralStarG


def _build_group(group_type: str, param) -> GroupAlgebra:
    if group_type in ("cyclic", "dihedral"):
        return GroupAlgebra(group_type, int(param))
    if group_type == "octahedral":
        return GroupAlgebra("octahedral")
    raise ValueError(f"unsupported group type: {group_type}")


def _featurize(samples, group_type: str, n_rot: int,
               featurizer: str = "default",
               n_feat: int = 14) -> torch.Tensor:
    """Build the (N, n_feat, |G|) feature tensor for a list of samples.

    `featurizer` controls which molecule-level summary is used. All
    options preserve the (n_feat, |G|) shape contract, see CONTRIBUTING.md.

      - default          : group-appropriate angular features (14 rows)
                           [legacy reinvention; underspecified relative
                            to the MATLAB original, kept for
                            backwards-compat on already-saved results]
      - matlab_angular   : faithful port of MATLAB QM9_experiment's
                           angular_features() (cyclic/dihedral only).
                           Numerically equivalent to MATLAB at 5.7e-14
                           per element (see data/test_matlab_equivalence.py).
                           This is the published-MATLAB-pipeline analog;
                           use this when reproducing or improving on the
                           MATLAB-1k QM9 results.
      - cm_extended      : 14 angular + 29 Coulomb-matrix sorted eigenvalues
                           replicated as invariant rows (cyclic group only)
    """
    if featurizer == "matlab_angular":
        if group_type not in ("cyclic", "dihedral"):
            raise ValueError(
                "matlab_angular featurizer is wired for cyclic/dihedral "
                "groups (port of MATLAB QM9_experiment.angular_features)."
            )
        return stack_batch(
            samples, matlab_angular_features,
            n_rot=n_rot, n_feat_target=n_feat,
        )
    if featurizer == "cm_extended":
        if group_type not in ("cyclic", "dihedral"):
            raise ValueError(
                "cm_extended featurizer is currently wired for cyclic/dihedral "
                "groups; the 14 angular rows it extends are z-axis specific."
            )
        return stack_batch(samples, coulomb_eig_extended_features, n_rot=n_rot)
    if group_type in ("cyclic", "dihedral"):
        return stack_batch(samples, cyclic_angular_features, n_rot=n_rot)
    if group_type == "octahedral":
        return stack_batch(samples, octahedral_features)
    raise ValueError(group_type)


def _build_target(samples, target: str) -> np.ndarray:
    """Assemble the regression target array for ``samples``.

    Three kinds of ``target`` are recognised, checked in this order:

      - any key of ``PROPERTY_INDEX``: the matching entry of each sample's
        ``properties`` is collected into one array.
      - ``"mu_vector"``: an (N, 3) array, the charge-weighted sum of
        centroid-centred coordinates per sample.
      - ``"alpha_tensor"``: an (N, 6) array, the charge-weighted second
        moments ordered xx, yy, zz, xy, xz, yz.

    Raises:
        ValueError: if ``target`` matches none of the above.
    """
    if target in PROPERTY_INDEX:
        column = PROPERTY_INDEX[target]
        return np.array([mol.properties[column] for mol in samples])

    if target == "mu_vector":
        # Compute dipole vector from Mulliken charges
        dipoles = np.zeros((len(samples), 3))
        for row, mol in enumerate(samples):
            centred = mol.coords - mol.coords.mean(axis=0, keepdims=True)
            weighted = mol.charges[:, None] * centred
            dipoles[row] = weighted.sum(axis=0)
        return dipoles

    if target == "alpha_tensor":
        # Approximate polarizability tensor proxy: use Mulliken-weighted
        # second moment as a 6-dim symmetric rank-2 target. The true α
        # tensor is not provided in QM9 (only the trace α is); we expose
        # this slot for QM7-X integration.
        moments = np.zeros((len(samples), 6))
        for row, mol in enumerate(samples):
            centred = mol.coords - mol.coords.mean(axis=0, keepdims=True)
            w = mol.charges
            x, y, z = centred[:, 0], centred[:, 1], centred[:, 2]
            # Component order: xx, yy, zz, xy, xz, yz.
            components = (
                (w * x ** 2).sum(),
                (w * y ** 2).sum(),
                (w * z ** 2).sum(),
                (w * x * y).sum(),
                (w * x * z).sum(),
                (w * y * z).sum(),
            )
            for col, value in enumerate(components):
                moments[row, col] = value
        return moments

    raise ValueError(f"unknown target: {target}")


def _parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--method", choices=["ridge", "neural"], required=True)
    ap.add_argument("--target", required=True,
                    help="property name (e.g. gap, mu, alpha) or vector/tensor target")
    ap.add_argument("--qm9_dir", required=True)
    ap.add_argument("--group", choices=["cyclic", "dihedral", "octahedral"], default="cyclic")
    # type=int keeps the value recorded in the result JSON an integer whether
    # or not the flag is given explicitly (it used to be a str when given).
    ap.add_argument("--group_param", type=int, default=12,
                    help="integer parameter for cyclic/dihedral groups; "
                         "ignored for octahedral")
    ap.add_argument("--n_rot", type=int, default=12)
    ap.add_argument("--n_feat", type=int, default=14,
                    help="number of feature rows per (atom, rotation). "
                         "MATLAB QM9_experiment defaults to 48; the legacy "
                         "Python default is 14. Use 48 with "
                         "--featurizer matlab_angular to reproduce the "
                         "published MATLAB-1k feature richness.")
    ap.add_argument("--max_molecules", type=int, default=None,
                    help="if set, subsample QM9 (for debugging). Default: full 134k")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--out_dir", default="results/")
    ap.add_argument("--epochs", type=int, default=300)
    ap.add_argument("--lr", type=float, default=0.003)
    ap.add_argument("--batch", type=int, default=256)
    ap.add_argument("--patience", type=int, default=20,
                    help="neural only: stop after this many consecutive epochs "
                         "without validation-loss improvement")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--hidden_widths", default="64 32",
                    help="space-separated widths for Neural-starG hidden layers; "
                         "default '64 32' matches the original manuscript spec. "
                         "Try '256 128' (~440k params) or '512 256 128' (~1.2M, "
                         "MACE-comparable) to test capacity scaling.")
    ap.add_argument("--featurizer", default="default",
                    choices=("default", "matlab_angular", "cm_extended"),
                    help="molecule-level feature recipe. 'default' uses the "
                         "group-appropriate angular projections (14 rows; "
                         "legacy reinvention). 'matlab_angular' is the "
                         "faithful port of MATLAB QM9_experiment's "
                         "angular_features (numerically equivalent to "
                         "MATLAB at 5.7e-14 per element; cyclic/dihedral). "
                         "'cm_extended' appends 29 Coulomb-matrix eigenvalues "
                         "as invariant rows for an information-richer "
                         "(43, |G|) tensor. See CONTRIBUTING.md for the "
                         "(n_feat, |G|) shape contract.")
    return ap.parse_args(argv)


def _resolve_device(requested: str) -> torch.device:
    """Return the torch device to run on, falling back to CPU without CUDA.

    Only a ``cuda*`` request is downgraded when CUDA is unavailable; any
    other device string is passed to ``torch.device`` unchanged.
    """
    if requested.startswith("cuda") and not torch.cuda.is_available():
        print(f"[device] CUDA unavailable; using cpu instead of {requested}")
        return torch.device("cpu")
    return torch.device(requested)


def _fit_ridge(X_tr, X_va, X_te, y_tr, y_va, G, device):
    """Fit RidgeCV on ★_G features and predict the test split.

    Feature normalisation (``norm``) is estimated on the training split and
    reused for validation and test. Train and validation rows are then
    pooled for the RidgeCV fit. Multi-output targets are fitted one column
    at a time.

    Returns:
        ``(y_pred, n_params)`` where ``n_params`` is the number of feature
        columns.
    """
    Phi_tr, norm = extract_starg_features(X_tr.to(device), G)
    Phi_va, _ = extract_starg_features(X_va.to(device), G, norm=norm)
    Phi_te, _ = extract_starg_features(X_te.to(device), G, norm=norm)
    Phi_tr = Phi_tr.cpu().numpy()
    Phi_va = Phi_va.cpu().numpy()
    Phi_te = Phi_te.cpu().numpy()

    # RidgeCV over the manuscript's grid; if y is multi-output, fit per column
    alphas = np.logspace(-3, 3, 7)
    # Pool train+val once instead of rebuilding the stack for every column.
    Phi_fit = np.vstack([Phi_tr, Phi_va])
    y_fit = np.concatenate([y_tr, y_va])
    if y_fit.ndim == 1:
        y_pred = RidgeCV(alphas=alphas).fit(Phi_fit, y_fit).predict(Phi_te)
    else:
        y_pred = np.stack(
            [
                RidgeCV(alphas=alphas).fit(Phi_fit, y_fit[:, j]).predict(Phi_te)
                for j in range(y_fit.shape[1])
            ],
            axis=1,
        )
    return y_pred, Phi_tr.shape[1]


def _fit_neural(X_tr, X_va, X_te, y_tr, y_va, G, device, *,
                hidden_widths, epochs, lr, batch, patience):
    """Train Neural ★_G with Adam and MSE loss, then predict the test split.

    The weights with the lowest validation loss are restored before
    predicting. Training stops early after ``patience`` consecutive epochs
    without validation improvement.

    Returns:
        ``(y_pred, n_params)`` where ``n_params`` counts all model
        parameters; ``y_pred`` is flattened to 1-D for scalar targets.
    """
    scalar_target = y_tr.ndim == 1
    model = NeuralStarG(
        layer_sizes=[X_tr.shape[1], *hidden_widths],
        G=G,
        output_dim=1 if scalar_target else y_tr.shape[1],
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    Xt = X_tr.to(device)
    Xv = X_va.to(device)
    yt = torch.tensor(y_tr, dtype=torch.float32, device=device)
    yv = torch.tensor(y_va, dtype=torch.float32, device=device)
    if scalar_target:
        # Give scalar targets a trailing axis to line up with output_dim=1.
        yt = yt.unsqueeze(-1)
        yv = yv.unsqueeze(-1)

    best_val, best_state, wait = float("inf"), None, 0
    n = Xt.shape[0]
    for _ in range(epochs):
        model.train()
        # The permutation is drawn from the default (seeded) generator so a
        # given --seed keeps producing the same batch order.
        perm = torch.randperm(n)
        for start in range(0, n, batch):
            idx = perm[start:start + batch]
            opt.zero_grad()
            loss = ((model(Xt[idx]) - yt[idx]) ** 2).mean()
            loss.backward()
            opt.step()

        # Evaluate in eval mode so any train-only layers the model may
        # contain (dropout, batch norm) do not perturb the validation loss.
        model.eval()
        with torch.no_grad():
            val_loss = ((model(Xv) - yv) ** 2).mean().item()
        if val_loss < best_val:
            best_val = val_loss
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        y_pred = model(X_te.to(device)).cpu().numpy()
    if scalar_target:
        y_pred = y_pred.flatten()
    n_params = sum(p.numel() for p in model.parameters())
    return y_pred, n_params


def _regression_metrics(y_true: np.ndarray, y_pred: np.ndarray):
    """Return ``(r2, rmse, mae)`` pooled over every element of the arrays.

    For multi-output targets the total sum of squares is taken around each
    column's own mean and then summed across columns.
    """
    residual = y_true - y_pred
    ss_res = (residual ** 2).sum()
    ss_tot = ((y_true - y_true.mean(axis=0, keepdims=True)) ** 2).sum()
    r2 = float(1 - ss_res / ss_tot)
    rmse = float(np.sqrt((residual ** 2).mean()))
    mae = float(np.abs(residual).mean())
    return r2, rmse, mae


def main(argv=None):
    """Run one ★_G training job end to end and write its JSON result.

    Steps: parse options, seed torch and NumPy, load QM9, featurize, split,
    fit the selected method (``ridge`` or ``neural``), score on the test
    split and dump metrics plus per-molecule predictions to
    ``<out_dir>/starg_<method>/<target>/seed<seed>.json``.

    Args:
        argv: optional argument list; ``None`` reads ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    hidden_widths = [int(w) for w in args.hidden_widths.split()]

    out_dir = Path(args.out_dir) / f"starg_{args.method}" / args.target
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"seed{args.seed}.json"

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = _resolve_device(args.device)

    print(f"[load] QM9 from {args.qm9_dir}")
    ds = QM9Dataset(args.qm9_dir, max_molecules=args.max_molecules)
    print(f"[load] {len(ds)} molecules")

    samples = [ds[i] for i in range(len(ds))]
    train_idx, val_idx, test_idx = qm9_split(len(ds), seed=args.seed)

    G = _build_group(args.group, args.group_param)
    G.to(device)

    print(f"[feat] computing {args.group} features (n_g={G.n})")
    t0 = time.time()
    X = _featurize(samples, args.group, args.n_rot,
                   featurizer=args.featurizer, n_feat=args.n_feat)  # (N, n_f, n_g)
    y = _build_target(samples, args.target)
    # NOTE: this timing covers target construction as well as featurization.
    feat_time = time.time() - t0

    X_tr, X_va, X_te = X[train_idx], X[val_idx], X[test_idx]
    y_tr, y_va, y_te = y[train_idx], y[val_idx], y[test_idx]

    is_neural = args.method == "neural"
    if is_neural:
        # Neural ★_G, hidden widths configurable via --hidden_widths
        y_pred, n_params = _fit_neural(
            X_tr, X_va, X_te, y_tr, y_va, G, device,
            hidden_widths=hidden_widths,
            epochs=args.epochs,
            lr=args.lr,
            batch=args.batch,
            patience=args.patience,
        )
    else:
        y_pred, n_params = _fit_ridge(X_tr, X_va, X_te, y_tr, y_va, G, device)

    r2, rmse, mae = _regression_metrics(y_te, y_pred)

    result = {
        "method": f"starg_{args.method}",
        "target": args.target,
        "group": args.group,
        "group_param": args.group_param,
        "featurizer": args.featurizer,
        "hidden_widths": args.hidden_widths if is_neural else None,
        "patience": args.patience if is_neural else None,
        "seed": args.seed,
        "n_train": len(train_idx),
        "n_val": len(val_idx),
        "n_test": len(test_idx),
        "n_total": len(ds),
        "n_params": int(n_params),
        "feat_time_s": feat_time,
        "r2": r2,
        "rmse": rmse,
        "mae": mae,
        # Per-test-molecule predictions for isomer_audit.py and
        # per_irrep_audit.py. ~150KB per JSON; cheap.
        "predictions": y_pred.tolist(),
        "test_idx": test_idx.tolist(),
    }
    with open(out_file, "w", encoding="utf-8") as fp:
        json.dump(result, fp, indent=2)
    print(f"[done] R2={r2:.4f}  RMSE={rmse:.4g}  -> {out_file}")


# Script entry point: run the training job only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()