"""Shared dataset adapter for the QM9 and QM7-X scalar-target experiments.

The three baseline training scripts (`train_baseline_mace.py`,
`train_baseline_schnet.py`, `train_baseline_e3nn.py`) all need to:

  1. load `samples` (each sample exposes `Z`, `coords`, plus a way to build
     the scalar target),
  2. build the per-sample target value `y` for a given target name,
  3. produce a reproducible (train, val, test) index split,
  4. know the element set the model should be configured for.

This module centralizes those four operations behind a uniform API so the
per-script branching on `--dataset {qm9,qm7x}` is one-liner-thin.

Targets:
  * `qm9`  : "gap", "alpha", "mu", "zpve" (scalar QM9 properties).
  * `qm7x` : "alpha_E", "alpha_T2", "alpha_iso" (octahedral irrep magnitudes
             of the molecular polarizability tensor).
"""

from __future__ import annotations

from typing import List, Tuple

import numpy as np

# Element set covered by each dataset (used to build z_tables / embeddings).
QM9_ELEMENTS = [1, 6, 7, 8, 9]
QM7X_ELEMENTS = [1, 6, 7, 8, 16, 17]  # H, C, N, O, S, Cl per QM7-X spec


# ---------------------------------------------------------------------------
# Load samples
# ---------------------------------------------------------------------------


def load_samples(dataset: str, data_dir: str, max_molecules=None):
    """Load the samples of ``dataset`` from ``data_dir``.

    Args:
        dataset: Dataset key, either ``"qm9"`` or ``"qm7x"``.
        data_dir: Location handed to the dataset-specific loader.
        max_molecules: Forwarded unchanged to the underlying loader
            (presumably a cap on the number of molecules -- confirm against
            ``data.qm9`` / ``data.qm7x``).

    Returns:
        For ``"qm9"``, a list holding ``ds[0] .. ds[len(ds) - 1]`` of a
        ``QM9Dataset``; for ``"qm7x"``, whatever ``load_qm7x_equilibrium``
        returns.

    Raises:
        ValueError: If ``dataset`` is neither ``"qm9"`` nor ``"qm7x"``.
    """
    if dataset == "qm7x":
        # Project modules are imported lazily so that only the backend that
        # is actually requested needs to be importable.
        from data.qm7x import load_qm7x_equilibrium

        return load_qm7x_equilibrium(data_dir, max_molecules=max_molecules)

    if dataset != "qm9":
        raise ValueError(f"unknown dataset {dataset!r}")

    from data.qm9 import QM9Dataset

    qm9 = QM9Dataset(data_dir, max_molecules=max_molecules)
    n_samples = len(qm9)
    # Index explicitly: only ``__len__`` and ``__getitem__`` are relied upon.
    return [qm9[idx] for idx in range(n_samples)]


# ---------------------------------------------------------------------------
# Build target array
# ---------------------------------------------------------------------------


def build_target(dataset: str, samples, target: str) -> np.ndarray:
    """Build the per-sample target values for ``target``.

    Args:
        dataset: Dataset key, either ``"qm9"`` or ``"qm7x"``.
        samples: Samples to read the target from (e.g. the result of
            ``load_samples`` for the same ``dataset``).
        target: Name of the target to build.

    Returns:
        For ``"qm9"``, a ``float64`` array: either the ``properties`` entry
        selected through ``PROPERTY_INDEX`` for every sample, or the output
        of ``train_starg._build_target`` when ``target`` is not a key of
        ``PROPERTY_INDEX``. For ``"qm7x"``, the result of
        ``qm7x_target_array`` passed through as-is.

    Raises:
        ValueError: If ``dataset`` is neither ``"qm9"`` nor ``"qm7x"``.
    """
    if dataset == "qm7x":
        from data.qm7x import qm7x_target_array

        return qm7x_target_array(samples, target)

    if dataset != "qm9":
        raise ValueError(f"unknown dataset {dataset!r}")

    from data.qm9 import PROPERTY_INDEX

    if target not in PROPERTY_INDEX:
        # Not a stored property: defer to the train_starg helper, which
        # handles derived targets.
        from train_starg import _build_target

        derived = _build_target(samples, target)
        return np.asarray(derived, dtype=np.float64)

    column = PROPERTY_INDEX[target]
    values = [sample.properties[column] for sample in samples]
    return np.array(values, dtype=np.float64)


# ---------------------------------------------------------------------------
# Split
# ---------------------------------------------------------------------------


def split_indices(
    dataset: str,
    n_total: int,
    seed: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return the index split for ``dataset``.

    Args:
        dataset: Dataset key, either ``"qm9"`` or ``"qm7x"``.
        n_total: Passed positionally to the dataset-specific split function
            (presumably the total number of samples -- confirm against
            ``qm9_split`` / ``qm7x_split``).
        seed: Passed as ``seed=`` to the dataset-specific split function.

    Returns:
        The result of ``qm9_split`` or ``qm7x_split``; per the annotation and
        the module docstring, a ``(train, val, test)`` triple of index
        arrays.

    Raises:
        ValueError: If ``dataset`` is neither ``"qm9"`` nor ``"qm7x"``.
    """
    if dataset == "qm9":
        from data.qm9 import qm9_split as make_split
    elif dataset == "qm7x":
        from data.qm7x import qm7x_split as make_split
    else:
        raise ValueError(f"unknown dataset {dataset!r}")

    return make_split(n_total, seed=seed)


# ---------------------------------------------------------------------------
# Element set
# ---------------------------------------------------------------------------


def element_set(dataset: str) -> List[int]:
    """Return the atomic numbers covered by ``dataset``.

    Args:
        dataset: Dataset key, either ``"qm9"`` or ``"qm7x"``.

    Returns:
        A new list with the contents of ``QM9_ELEMENTS`` or
        ``QM7X_ELEMENTS``; mutating it leaves the module constants untouched.

    Raises:
        ValueError: If ``dataset`` is neither ``"qm9"`` nor ``"qm7x"``.
    """
    if dataset == "qm9":
        elements = QM9_ELEMENTS
    elif dataset == "qm7x":
        elements = QM7X_ELEMENTS
    else:
        raise ValueError(f"unknown dataset {dataset!r}")

    # Copy so callers cannot modify the shared module-level constant.
    return list(elements)