"""Full QM9 dataset loader. Loads the 134k QM9 .xyz files and the 19 quantum-chemical properties into PyTorch tensors. The loader is compatible with both the existing MATLAB pipeline (single .xyz directory) and PyTorch Geometric's QM9 wrapper. Mulliken charges are kept for the dipole-vector and polarizability targets used in the Wigner-Eckart and tensor-prediction experiments. The 12 standard targets are: 0: dipole moment magnitude (D) 1: isotropic polarizability α (ų) 2: HOMO energy (Ha) 3: LUMO energy (Ha) 4: HOMO-LUMO gap (Ha) 5: spatial extent (Ų) 6: ZPVE (Ha) 7: U0 internal energy at 0K (Ha) 8: U internal energy at 298K (Ha) 9: H enthalpy at 298K (Ha) 10: G free energy at 298K (Ha) 11: Cv heat capacity (cal/mol/K) """ from __future__ import annotations import os import re from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Tuple import numpy as np import torch from torch.utils.data import Dataset PROPERTY_NAMES = [ "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv", ] PROPERTY_INDEX = {name: i for i, name in enumerate(PROPERTY_NAMES)} @dataclass class QM9Sample: coords: np.ndarray # (n_atoms, 3) Z: np.ndarray # atomic numbers (n_atoms,) charges: np.ndarray # Mulliken partial charges (n_atoms,) properties: np.ndarray # (12,) standard targets smiles: str = "" class QM9Dataset(Dataset): """Reads QM9 .xyz files and exposes per-molecule numpy arrays.""" def __init__( self, xyz_dir: str | Path, max_atoms: int = 29, max_molecules: Optional[int] = None, cache_path: Optional[str] = None, ): self.xyz_dir = Path(xyz_dir) self.max_atoms = max_atoms files = sorted(self.xyz_dir.glob("*.xyz")) if max_molecules is not None: files = files[:max_molecules] self.files = files if cache_path and Path(cache_path).exists(): cache = np.load(cache_path, allow_pickle=True) self.samples = list(cache["samples"]) else: self.samples = [self._parse(f) for f in self.files] if cache_path: np.savez(cache_path, samples=np.array(self.samples, dtype=object)) @staticmethod def _parse(path: Path) -> QM9Sample: with open(path, "r") as fp: lines = fp.readlines() n_atoms = int(lines[0].strip()) # Properties on line 2: "gdb 1 <12 floats>" prop_tokens = lines[1].split() # First two tokens are "gdb " props = np.array([float(t) for t in prop_tokens[2:14]]) coords = np.zeros((n_atoms, 3)) Z = np.zeros(n_atoms, dtype=int) charges = np.zeros(n_atoms) atom_to_z = {"H": 1, "C": 6, "N": 7, "O": 8, "F": 9} for i in range(n_atoms): tokens = lines[2 + i].split() symbol = tokens[0] Z[i] = atom_to_z[symbol] # XYZ files use *^ to denote scientific notation; replace xyz = [float(t.replace("*^", "e")) for t in tokens[1:4]] coords[i] = xyz if len(tokens) >= 5: charges[i] = float(tokens[4].replace("*^", "e")) # SMILES is optional metadata. Some converters emit an empty SMILES # line (just a tab); guard against split() returning []. _smiles_tokens = ( lines[3 + n_atoms].split() if len(lines) > 3 + n_atoms else [] ) smiles = _smiles_tokens[0] if _smiles_tokens else "" return QM9Sample(coords=coords, Z=Z, charges=charges, properties=props, smiles=smiles) def __len__(self) -> int: return len(self.samples) def __getitem__(self, idx: int) -> QM9Sample: return self.samples[idx] def qm9_split( n: int, seed: int = 0, fractions: Tuple[float, float, float] = (0.7, 0.15, 0.15), ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: rng = np.random.default_rng(seed) perm = rng.permutation(n) n_tr = int(fractions[0] * n) n_va = int(fractions[1] * n) return perm[:n_tr], perm[n_tr : n_tr + n_va], perm[n_tr + n_va :]