tensor-group-sym / python / large_scale / data / pyg_to_xyz.py
pyg_to_xyz.py
Raw
"""Convert a torch_geometric QM9 dataset to per-molecule .xyz files.

QM9Dataset (data/qm9.py) consumes the per-molecule .xyz layout that
figshare ships. PyTorch Geometric's QM9 wrapper stores everything as
a single processed .pt file. This script bridges the two:

    pip install torch torch-geometric
    python -c "from torch_geometric.datasets import QM9; QM9(root='/path/to/qm9_pyg')"
    python pyg_to_xyz.py --pyg_root /path/to/qm9_pyg --xyz_dir /path/to/qm9/dsgdb9nsd

After this, point QM9Dataset at the xyz_dir and everything downstream
works.

Caveat: PyG's QM9 does not include the Mulliken partial charges that the
real figshare .xyz files carry as a 5th column. This converter writes
zeros for the charge column. That is sufficient for the scalar QM9
targets (gap, alpha, mu, zpve, ...) but means the dipole-vector
(mu_vector) and polarizability-tensor (alpha_tensor) targets, which are
built from Mulliken charges in train_starg.py, will be all-zero. If you
need those targets, do the manual figshare download instead (see
HANDOFF_CCC.md, option B).
"""

from __future__ import annotations

import argparse
from pathlib import Path

import numpy as np


# QM9 atomic-number set
Z_TO_SYMBOL = {1: "H", 6: "C", 7: "N", 8: "O", 9: "F"}

# Hartree -> eV factor that torch_geometric.datasets.QM9 multiplies the
# figshare energies by during processing; divided back out when writing.
HARTREE_TO_EV = 27.211386246

# PyG's QM9 ``y`` row (19 columns) is ordered
#   0 mu, 1 alpha, 2 homo, 3 lumo, 4 gap, 5 r2, 6 zpve, 7 U0, 8 U, 9 H,
#   10 G, 11 Cv, 12-15 atomization energies, 16 A, 17 B, 18 C,
# whereas the figshare .xyz comment line is
#   gdb <index> A B C mu alpha homo lumo gap r2 zpve U0 U H G Cv.
# For each figshare property slot, in order, this is the PyG column to read.
_FIGSHARE_FROM_PYG = (16, 17, 18, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)

# PyG columns stored in eV; figshare stores the same quantities in Hartree.
_PYG_EV_COLUMNS = frozenset({2, 3, 4, 6, 7, 8, 9, 10})

# Placeholder InChI; figshare files carry two tab-separated InChI strings.
_INCHI_PLACEHOLDER = "InChI=1S/from-pyg"


def format_xyz(gdb_index: int, z, pos, y, smiles: str | None = None) -> str:
    """Render one molecule as the text of a figshare-layout QM9 .xyz file.

    Args:
        gdb_index: 1-based GDB-9 index written after ``gdb`` on line 2 (also
            used for the SMILES placeholder when ``smiles`` is falsy).
        z: Atomic numbers, one per atom.
        pos: Cartesian coordinates; reshaped to ``(n_atoms, 3)``.
        y: A PyG QM9 target row (any shape; flattened). At least the 12
            standard columns are required; if the rotational constants
            (columns 16-18) are absent they are written as 0.0.
        smiles: SMILES string, or a falsy value to write ``MOL<gdb_index>``.

    Returns:
        The file contents, newline-terminated.

    Raises:
        ValueError: If ``y`` has fewer than 12 columns, ``pos`` does not have
            one row per atom, or an atomic number is not in ``Z_TO_SYMBOL``.
    """
    z = np.asarray(z).astype(int).ravel()
    pos = np.asarray(pos, dtype=float).reshape(-1, 3)
    y = np.asarray(y, dtype=float).ravel()
    if y.size < 12:
        raise ValueError(f"expected >= 12 QM9 target columns, got {y.size}")
    if len(pos) != len(z):
        raise ValueError(f"{len(z)} atomic numbers but {len(pos)} positions")
    n_atoms = len(z)

    # Reorder to the figshare slot order and undo PyG's Hartree -> eV scaling.
    props = []
    for col in _FIGSHARE_FROM_PYG:
        if col >= y.size:
            props.append(0.0)  # rotational constants missing from this row
        elif col in _PYG_EV_COLUMNS:
            props.append(y[col] / HARTREE_TO_EV)
        else:
            props.append(y[col])

    lines = [
        str(n_atoms),
        f"gdb {gdb_index}\t" + "\t".join(f"{p:.6f}" for p in props),
    ]
    for zi, (cx, cy, cz) in zip(z, pos):
        try:
            sym = Z_TO_SYMBOL[int(zi)]
        except KeyError:
            raise ValueError(
                f"atomic number {int(zi)} is not in the QM9 element set "
                f"{sorted(Z_TO_SYMBOL)}"
            ) from None
        # 5th column is the Mulliken charge, which PyG does not store; zero.
        lines.append(f"{sym}\t{cx:.6f}\t{cy:.6f}\t{cz:.6f}\t0.000000")
    # Harmonic-frequency row (parsed but unused by QM9Dataset). Real files
    # have 3N-6 entries (3N-5 for linear molecules, not detected here).
    lines.append("\t".join(["0.0"] * max(3 * n_atoms - 6, 1)))
    # SMILES row: use the real SMILES when available, else a non-empty
    # placeholder so parsers doing `.split()[0]` still get a token.
    sm = smiles or f"MOL{gdb_index}"
    lines.append(f"{sm}\t{sm}")
    lines.append(f"{_INCHI_PLACEHOLDER}\t{_INCHI_PLACEHOLDER}")
    return "\n".join(lines) + "\n"


def _gdb_index(d):
    """Return the GDB-9 index parsed from ``d.name`` ("gdb_<i>"), else None.

    PyG omits the molecules flagged in figshare's ``uncharacterized.txt``, so
    a sample's dataset position + 1 is not its GDB-9 index once the first
    omitted molecule has been passed; the stored name is authoritative.
    """
    name = getattr(d, "name", None)
    if isinstance(name, str) and name.startswith("gdb_"):
        try:
            return int(name[len("gdb_"):])
        except ValueError:
            return None
    return None


def main():
    """Convert a PyG QM9 dataset into one figshare-layout .xyz per molecule."""
    ap = argparse.ArgumentParser(
        description="Convert a torch_geometric QM9 dataset to per-molecule .xyz files."
    )
    ap.add_argument("--pyg_root", required=True,
                    help="root passed to torch_geometric.datasets.QM9")
    ap.add_argument("--xyz_dir", required=True,
                    help="output directory; one .xyz per molecule, named "
                         "dsgdb9nsd_<6-digit GDB index>.xyz to match the "
                         "figshare layout.")
    ap.add_argument("--max_molecules", type=int, default=None,
                    help="convert only the first N molecules of the PyG dataset")
    args = ap.parse_args()
    if args.max_molecules is not None and args.max_molecules < 0:
        ap.error("--max_molecules must be >= 0")

    try:
        from torch_geometric.datasets import QM9
    except ImportError as e:
        raise SystemExit(
            "torch_geometric is not installed. Install with:\n"
            "    pip install torch-geometric\n"
            "(if you also need torch-scatter / torch-sparse for downstream PyG\n"
            "models, follow https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)"
        ) from e

    print(f"[load] PyG QM9 from {args.pyg_root}")
    ds = QM9(root=args.pyg_root)
    print(f"[load] {len(ds)} molecules")

    xyz_dir = Path(args.xyz_dir)
    xyz_dir.mkdir(parents=True, exist_ok=True)

    n = len(ds) if args.max_molecules is None else min(args.max_molecules, len(ds))
    warned_fallback = False
    for position in range(n):
        d = ds[position]
        gdb_index = _gdb_index(d)
        if gdb_index is None:
            # Older/preprocessed PyG builds may lack `name`; fall back to the
            # dataset position, which drifts from the true GDB-9 index.
            if not warned_fallback:
                print("[warn] PyG samples carry no 'gdb_<i>' name; numbering by "
                      "dataset position, which differs from the GDB-9 index "
                      "after the first molecule PyG skipped")
                warned_fallback = True
            gdb_index = position + 1

        text = format_xyz(
            gdb_index,
            d.z.cpu().numpy(),
            d.pos.cpu().numpy(),
            d.y.cpu().numpy(),
            smiles=getattr(d, "smiles", None),
        )
        path = xyz_dir / f"dsgdb9nsd_{gdb_index:06d}.xyz"
        path.write_text(text, encoding="utf-8")
        if (position + 1) % 10000 == 0:
            print(f"  wrote {position + 1} / {n}")

    print(f"[done] wrote {n} .xyz files to {xyz_dir}")


# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()