"""Convert a torch_geometric QM9 dataset to per-molecule .xyz files.
QM9Dataset (data/qm9.py) consumes the per-molecule .xyz layout that
figshare ships. PyTorch Geometric's QM9 wrapper stores everything as
a single processed .pt file. This script bridges the two:

    pip install torch torch-geometric
    python -c "from torch_geometric.datasets import QM9; QM9(root='/path/to/qm9_pyg')"
    python pyg_to_xyz.py --pyg_root /path/to/qm9_pyg --xyz_dir /path/to/qm9/dsgdb9nsd

After this, point QM9Dataset at the xyz_dir and everything downstream
works.
Caveat: PyG's QM9 does not include the Mulliken partial charges that the
real figshare .xyz files carry as a 5th column. This converter writes
zeros for the charge column. That is sufficient for the scalar QM9
targets (gap, alpha, mu, zpve, ...) but means the dipole-vector
(mu_vector) and polarizability-tensor (alpha_tensor) targets, which are
built from Mulliken charges in train_starg.py, will be all-zero. If you
need those targets, do the manual figshare download instead; see
HANDOFF_CCC.md option B.
"""
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
# Atomic number -> element symbol for the QM9 atomic-number set.
# Any atomic number missing from this table is written as "X" by the
# converter below.
Z_TO_SYMBOL = {
    1: "H",  # hydrogen
    6: "C",  # carbon
    7: "N",  # nitrogen
    8: "O",  # oxygen
    9: "F",  # fluorine
}
def _format_xyz_record(mol_id, z, pos, props, smiles, charges=None) -> str:
    """Render one molecule as text in the figshare-style QM9 ``.xyz`` layout.

    The record consists of, in order: the atom count, a ``gdb <id> <props...>``
    header line, one tab-separated ``symbol x y z charge`` line per atom, a
    placeholder frequencies line, a SMILES line and a placeholder InChI line.

    Args:
        mol_id: 1-based molecule id written after ``gdb`` on the header line.
        z: Per-atom atomic numbers.
        pos: Per-atom ``(x, y, z)`` coordinates, same length as ``z``.
        props: Scalar properties appended to the header line.
        smiles: Token written (twice, tab-separated) on the SMILES line.
        charges: Optional per-atom partial charges; defaults to all zeros.

    Returns:
        The full file contents, terminated by a newline.
    """
    n_atoms = len(z)
    if charges is None:
        charges = [0.0] * n_atoms

    lines = [
        str(n_atoms),
        f"gdb {mol_id} " + " ".join(f"{p:.6f}" for p in props),
    ]
    for zi, (x, y, zc), q in zip(z, pos, charges):
        # Unknown atomic numbers are written as "X" rather than aborting.
        sym = Z_TO_SYMBOL.get(int(zi), "X")
        lines.append(f"{sym}\t{x:.6f}\t{y:.6f}\t{zc:.6f}\t{q:.6f}")

    # Vibrational frequencies row (parsed but unused by QM9Dataset). The
    # figshare files carry 3N-6 values (3N-5 for linear molecules); emit
    # 3N-6 zero placeholders, and always at least one token so the line
    # is never empty.
    n_freq = max(3 * n_atoms - 6, 1)
    lines.append(" ".join(["0.0"] * n_freq))

    lines.append(f"{smiles}\t{smiles}")
    lines.append("InChI=1S/from-pyg")
    return "\n".join(lines) + "\n"


def main(argv: list[str] | None = None) -> None:
    """Export a PyG QM9 dataset as one ``.xyz`` file per molecule.

    Loads ``torch_geometric.datasets.QM9`` from ``--pyg_root`` and writes
    ``dsgdb9nsd_<6-digit>.xyz`` files into ``--xyz_dir`` (created if
    missing). ``--max_molecules`` limits how many molecules are exported.

    Args:
        argv: Command-line arguments to parse; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: If ``torch_geometric`` cannot be imported, or on
            argparse errors.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--pyg_root", required=True,
                    help="root passed to torch_geometric.datasets.QM9")
    ap.add_argument("--xyz_dir", required=True,
                    help="output directory; one .xyz per molecule, named "
                         "dsgdb9nsd_<6-digit>.xyz to match the figshare layout.")
    ap.add_argument("--max_molecules", type=int, default=None)
    args = ap.parse_args(argv)

    # Imported lazily so that --help works and a missing dependency yields
    # an actionable message instead of a bare traceback.
    try:
        from torch_geometric.datasets import QM9
    except ImportError as e:
        raise SystemExit(
            "torch_geometric is not installed. Install with:\n"
            "  pip install torch-geometric\n"
            "(if you also need torch-scatter / torch-sparse for downstream PyG\n"
            "models, follow https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)"
        ) from e

    print(f"[load] PyG QM9 from {args.pyg_root}")
    ds = QM9(root=args.pyg_root)
    print(f"[load] {len(ds)} molecules")

    xyz_dir = Path(args.xyz_dir)
    xyz_dir.mkdir(parents=True, exist_ok=True)

    if args.max_molecules is None:
        n = len(ds)
    else:
        # Clamp to [0, len(ds)] so a negative limit exports nothing and the
        # final summary never reports a negative file count.
        n = max(0, min(args.max_molecules, len(ds)))

    for idx in range(n):
        d = ds[idx]
        # NOTE(review): ids are positional (idx + 1). If the PyG dataset
        # omits some molecules, these ids will not match the original gdb9
        # numbering -- confirm whether anything downstream relies on that.
        mol_id = idx + 1
        z = d.z.cpu().numpy().astype(int)
        pos = d.pos.cpu().numpy()
        # PyG y has shape (1, 19); take the first 12 (standard QM9 props).
        # Indices 0..11 = (mu, alpha, homo, lumo, gap, R2, zpve, U0, U, H, G, Cv).
        props = d.y.cpu().numpy().flatten()[:12]
        # SMILES placeholder: prefer the real SMILES if PyG stored one, else
        # use a non-empty index-based placeholder so downstream parsers that
        # do `.split()[0]` get a token rather than IndexError on an empty line.
        sm = getattr(d, "smiles", None) or f"MOL{mol_id}"

        # Mulliken charges are absent in PyG, so the helper's all-zero
        # default is used for the charge column (see module docstring).
        record = _format_xyz_record(mol_id, z, pos, props, sm)
        path = xyz_dir / f"dsgdb9nsd_{mol_id:06d}.xyz"
        # Explicit encoding: do not depend on the platform default.
        path.write_text(record, encoding="utf-8")

        if mol_id % 10000 == 0:
            print(f"  wrote {mol_id} / {n}")

    print(f"[done] wrote {n} .xyz files to {xyz_dir}")
# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()