"""Full QM9 dataset loader.
Loads the ~134k QM9 .xyz files and their 12 standard quantum-chemical
target properties into per-molecule NumPy arrays (see ``QM9Sample``) for
use with PyTorch. The loader is compatible with both the existing MATLAB
pipeline (single .xyz directory) and PyTorch Geometric's QM9 wrapper.
Mulliken partial charges are kept for the dipole-vector and polarizability
targets used in the Wigner-Eckart and tensor-prediction experiments.
The 12 standard targets are:
0: dipole moment magnitude (D)
1: isotropic polarizability α (ų)
2: HOMO energy (Ha)
3: LUMO energy (Ha)
4: HOMO-LUMO gap (Ha)
5: <R²> spatial extent (Ų)
6: ZPVE (Ha)
7: U0 internal energy at 0K (Ha)
8: U internal energy at 298K (Ha)
9: H enthalpy at 298K (Ha)
10: G free energy at 298K (Ha)
11: Cv heat capacity (cal/mol/K)
"""
from __future__ import annotations
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset
# Short names of the 12 standard QM9 targets, in the column order used by
# ``QM9Sample.properties`` (see the numbered list in the module docstring).
PROPERTY_NAMES = [
    "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve",
    "U0", "U", "H", "G", "Cv",
]
# Reverse lookup: target name -> column index into ``properties``.
PROPERTY_INDEX = dict(zip(PROPERTY_NAMES, range(len(PROPERTY_NAMES))))
@dataclass()
class QM9Sample:
    """One parsed QM9 molecule.

    Arrays are indexed per atom in the order the atoms appear in the .xyz
    file; ``properties`` follows the column order of ``PROPERTY_NAMES``.
    """

    # Cartesian coordinates, shape (n_atoms, 3).
    coords: np.ndarray
    # Atomic numbers, shape (n_atoms,).
    Z: np.ndarray
    # Mulliken partial charges, shape (n_atoms,).
    charges: np.ndarray
    # The 12 standard regression targets, shape (12,).
    properties: np.ndarray
    # Optional SMILES string; empty when the file provides none.
    smiles: str = ""
class QM9Dataset(Dataset):
    """Reads QM9 .xyz files and exposes per-molecule numpy arrays.

    Each item is a ``QM9Sample`` holding coordinates, atomic numbers,
    Mulliken partial charges, the 12 standard targets (ordered as in
    ``PROPERTY_NAMES``) and an optional SMILES string.

    Args:
        xyz_dir: Directory containing the ``*.xyz`` files; files are read in
            sorted filename order.
        max_atoms: Stored as ``self.max_atoms``. The parser here does not use
            it (presumably consumed by downstream padding code — verify).
        max_molecules: If given, keep only the first ``max_molecules`` files
            (and, when loading from a cache, the first that many samples).
        cache_path: Optional cache of parsed samples. ``np.savez`` always
            writes a ``.npz`` file, so the suffix is appended when missing;
            an existing file at the literal path is still honoured.
    """

    def __init__(
        self,
        xyz_dir: str | Path,
        max_atoms: int = 29,
        max_molecules: Optional[int] = None,
        cache_path: Optional[str] = None,
    ):
        self.xyz_dir = Path(xyz_dir)
        self.max_atoms = max_atoms
        files = sorted(self.xyz_dir.glob("*.xyz"))
        if max_molecules is not None:
            files = files[:max_molecules]
        self.files = files

        cache_file: Optional[Path] = None
        if cache_path:
            literal = Path(cache_path)
            # np.savez appends ".npz" to names lacking it; without this the
            # existence check below would never find the cache it wrote.
            cache_file = literal if literal.exists() else self._npz_path(literal)

        if cache_file is not None and cache_file.exists():
            # SECURITY: allow_pickle=True unpickles arbitrary objects. Only
            # load caches written by this class, never untrusted files.
            # NOTE: the cache is not invalidated if xyz_dir's contents change.
            with np.load(cache_file, allow_pickle=True) as cache:
                samples = list(cache["samples"])
            if max_molecules is not None:
                # Keep len(self.samples) consistent with the truncated file list.
                samples = samples[:max_molecules]
            self.samples = samples
        else:
            self.samples = [self._parse(f) for f in self.files]
            if cache_file is not None:
                np.savez(cache_file, samples=np.array(self.samples, dtype=object))

    @staticmethod
    def _npz_path(cache_path: str | Path) -> Path:
        """Return *cache_path* with the ``.npz`` suffix that ``np.savez`` adds."""
        path = Path(cache_path)
        if path.name.endswith(".npz"):
            return path
        return path.with_name(path.name + ".npz")

    @staticmethod
    def _parse_properties(tokens: List[str]) -> np.ndarray:
        """Extract the 12 standard targets from the tokenised property line.

        The standard QM9 property line is
        ``gdb <index> A B C mu alpha homo lumo gap r2 zpve U0 U H G Cv``
        (17 tokens). The rotational constants A, B, C precede the 12 targets
        and are skipped. A legacy layout with the 12 targets directly after
        ``gdb <index>`` is also accepted.

        Raises:
            ValueError: if fewer than 12 values follow ``gdb <index>``.
        """
        if len(tokens) >= 17:
            start = 5  # skip "gdb", the index, and rotational constants A, B, C
        elif len(tokens) >= 14:
            start = 2  # legacy layout: targets immediately follow "gdb <index>"
        else:
            raise ValueError(
                "expected at least 12 property values after 'gdb <index>', "
                f"got {max(len(tokens) - 2, 0)}"
            )
        return np.array(
            [float(t.replace("*^", "e")) for t in tokens[start : start + 12]]
        )

    @staticmethod
    def _parse(path: Path) -> QM9Sample:
        """Parse one QM9 ``.xyz`` file into a ``QM9Sample``.

        Layout read here:
        - line 1: the atom count;
        - line 2: the ``gdb`` property line;
        - then one line per atom: ``<symbol> x y z [charge]``.

        The line right after the atom block is skipped. The line after that,
        if present, supplies the SMILES string (its first whitespace token).

        Raises:
            KeyError: for an element symbol outside H, C, N, O, F.
            ValueError: for a malformed property line.
        """
        with open(path, "r", encoding="utf-8") as fp:
            lines = fp.readlines()
        n_atoms = int(lines[0].strip())
        props = QM9Dataset._parse_properties(lines[1].split())
        coords = np.zeros((n_atoms, 3))
        Z = np.zeros(n_atoms, dtype=int)
        charges = np.zeros(n_atoms)
        atom_to_z = {"H": 1, "C": 6, "N": 7, "O": 8, "F": 9}
        for i in range(n_atoms):
            tokens = lines[2 + i].split()
            Z[i] = atom_to_z[tokens[0]]
            # QM9 files use Mathematica-style "*^" for scientific notation.
            coords[i] = [float(t.replace("*^", "e")) for t in tokens[1:4]]
            if len(tokens) >= 5:
                charges[i] = float(tokens[4].replace("*^", "e"))
        # SMILES is optional metadata. Some converters emit an empty SMILES
        # line (just a tab); guard against split() returning [].
        smiles_tokens = (
            lines[3 + n_atoms].split() if len(lines) > 3 + n_atoms else []
        )
        smiles = smiles_tokens[0] if smiles_tokens else ""
        return QM9Sample(coords=coords, Z=Z, charges=charges, properties=props, smiles=smiles)

    def __len__(self) -> int:
        """Number of loaded molecules."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> QM9Sample:
        """Return the parsed sample at position *idx*."""
        return self.samples[idx]
def qm9_split(
    n: int,
    seed: int = 0,
    fractions: Tuple[float, float, float] = (0.7, 0.15, 0.15),
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Randomly partition ``range(n)`` into train / validation / test indices.

    The indices are shuffled with ``np.random.default_rng(seed)``. The first
    ``int(fractions[0] * n)`` go to train and the next
    ``int(fractions[1] * n)`` go to validation. The test split receives
    everything left over: ``fractions[2]`` is not read, and test also
    absorbs the remainder from integer truncation.

    Returns:
        ``(train_idx, val_idx, test_idx)`` as disjoint integer index arrays.
    """
    shuffled = np.random.default_rng(seed).permutation(n)
    train_stop = int(fractions[0] * n)
    val_stop = train_stop + int(fractions[1] * n)
    train_idx = shuffled[:train_stop]
    val_idx = shuffled[train_stop:val_stop]
    test_idx = shuffled[val_stop:]
    return train_idx, val_idx, test_idx