"""Shared dataset adapter for the QM9 and QM7-X scalar-target experiments.
The three baseline training scripts (`train_baseline_mace.py`,
`train_baseline_schnet.py`, `train_baseline_e3nn.py`) all need to:
1. load `samples` (each sample exposes `Z`, `coords`, plus a way to
build the scalar target),
2. build the per-sample target value `y` for a given target name,
3. produce a reproducible (train, val, test) index split,
4. know the element set the model should be configured for.
This module centralizes those four operations behind a uniform API so
that each script's branching on `--dataset {qm9,qm7x}` reduces to a
single call per operation.
Targets:
* `qm9` : "gap", "alpha", "mu", "zpve" (scalar QM9 properties).
* `qm7x` : "alpha_E", "alpha_T2", "alpha_iso" (octahedral irrep
magnitudes of the molecular polarizability tensor).
"""
from __future__ import annotations
from typing import List, Tuple
import numpy as np
# Element set covered by each dataset (used to build z_tables / embeddings).
# Entries are atomic numbers; `element_set()` hands out copies of these lists.
QM9_ELEMENTS = [1, 6, 7, 8, 9]  # H, C, N, O, F
QM7X_ELEMENTS = [1, 6, 7, 8, 16, 17] # H, C, N, O, S, Cl per QM7-X spec
# ---------------------------------------------------------------------------
# Load samples
# ---------------------------------------------------------------------------
def load_samples(dataset: str, data_dir: str, max_molecules=None):
    """Load the samples of the requested dataset.

    Args:
        dataset: Dataset key, either ``"qm9"`` or ``"qm7x"``.
        data_dir: Dataset location, forwarded unchanged to the
            dataset-specific loader.
        max_molecules: Optional limit forwarded unchanged to the loader
            (defaults to ``None``).

    Returns:
        For ``"qm9"``, a list holding one item per index of the
        ``QM9Dataset``. For ``"qm7x"``, whatever
        ``load_qm7x_equilibrium`` returns.

    Raises:
        ValueError: If ``dataset`` is not a recognised key.
    """
    if dataset not in ("qm9", "qm7x"):
        raise ValueError(f"unknown dataset {dataset!r}")

    # Dataset modules are imported lazily so that only the selected
    # dataset's module is loaded.
    if dataset == "qm7x":
        from data.qm7x import load_qm7x_equilibrium

        return load_qm7x_equilibrium(data_dir, max_molecules=max_molecules)

    from data.qm9 import QM9Dataset

    qm9 = QM9Dataset(data_dir, max_molecules=max_molecules)
    count = len(qm9)
    # Materialise the dataset into a plain list by indexing each item.
    return [qm9[idx] for idx in range(count)]
# ---------------------------------------------------------------------------
# Build target array
# ---------------------------------------------------------------------------
def build_target(dataset: str, samples, target: str) -> np.ndarray:
    """Build the per-sample target values ``y`` for a target name.

    Args:
        dataset: Dataset key, either ``"qm9"`` or ``"qm7x"``.
        samples: Samples as produced by ``load_samples`` for the same
            dataset.
        target: Name of the target to extract.

    Returns:
        For ``"qm9"``, a ``float64`` array with one value per sample.
        For ``"qm7x"``, the result of ``qm7x_target_array`` unchanged.

    Raises:
        ValueError: If ``dataset`` is not a recognised key.
    """
    if dataset not in ("qm9", "qm7x"):
        raise ValueError(f"unknown dataset {dataset!r}")

    if dataset == "qm7x":
        from data.qm7x import qm7x_target_array

        return qm7x_target_array(samples, target)

    from data.qm9 import PROPERTY_INDEX

    if target not in PROPERTY_INDEX:
        # Targets absent from PROPERTY_INDEX are treated as derived
        # targets and delegated to the train_starg helper.
        from train_starg import _build_target

        derived = _build_target(samples, target)
        return np.asarray(derived, dtype=np.float64)

    # Direct property: read the indexed entry from every sample.
    column = PROPERTY_INDEX[target]
    values = [sample.properties[column] for sample in samples]
    return np.array(values, dtype=np.float64)
# ---------------------------------------------------------------------------
# Split
# ---------------------------------------------------------------------------
def split_indices(
    dataset: str,
    n_total: int,
    seed: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Produce the (train, val, test) index split for a dataset.

    Args:
        dataset: Dataset key, either ``"qm9"`` or ``"qm7x"``.
        n_total: Total number of samples, forwarded to the splitter.
        seed: Seed forwarded to the dataset-specific splitter.

    Returns:
        The value returned by ``qm9_split`` or ``qm7x_split`` for the
        given ``n_total`` and ``seed``.

    Raises:
        ValueError: If ``dataset`` is not a recognised key.
    """
    # Pick the dataset-specific splitter (imported lazily), then make a
    # single call with the shared arguments.
    if dataset == "qm9":
        from data.qm9 import qm9_split as splitter
    elif dataset == "qm7x":
        from data.qm7x import qm7x_split as splitter
    else:
        raise ValueError(f"unknown dataset {dataset!r}")
    return splitter(n_total, seed=seed)
# ---------------------------------------------------------------------------
# Element set
# ---------------------------------------------------------------------------
def element_set(dataset: str) -> List[int]:
    """Return the element list the model should be configured for.

    Args:
        dataset: Dataset key, either ``"qm9"`` or ``"qm7x"``.

    Returns:
        A new list copied from ``QM9_ELEMENTS`` or ``QM7X_ELEMENTS``.

    Raises:
        ValueError: If ``dataset`` is not a recognised key.
    """
    if dataset not in ("qm9", "qm7x"):
        raise ValueError(f"unknown dataset {dataset!r}")
    elements = QM9_ELEMENTS if dataset == "qm9" else QM7X_ELEMENTS
    # Return a copy so callers cannot mutate the module-level constant.
    return list(elements)