# tensor-group-sym/python/large_scale/data/dataset_adapter.py
"""Shared dataset adapter for the QM9 and QM7-X scalar-target experiments.

The three baseline training scripts (`train_baseline_mace.py`,
`train_baseline_schnet.py`, `train_baseline_e3nn.py`) all need to:

  1. load `samples` (each sample exposes `Z`, `coords`, plus a way to
     build the scalar target),
  2. build the per-sample target value `y` for a given target name,
  3. produce a reproducible (train, val, test) index split,
  4. know the element set the model should be configured for.

This module centralizes those four operations behind a uniform API so
the per-script branching on `--dataset {qm9,qm7x}` is one-liner-thin.

Targets:
  * `qm9`  : "gap", "alpha", "mu", "zpve" (scalar QM9 properties).
  * `qm7x` : "alpha_E", "alpha_T2", "alpha_iso" (octahedral irrep
             magnitudes of the molecular polarizability tensor).
"""

from __future__ import annotations

from typing import List, Tuple

import numpy as np


# Element set covered by each dataset (used to build z_tables / embeddings).
# Entries are atomic numbers; `element_set()` hands out copies of these lists.
QM9_ELEMENTS = [1, 6, 7, 8, 9]         # H, C, N, O, F
QM7X_ELEMENTS = [1, 6, 7, 8, 16, 17]   # H, C, N, O, S, Cl per QM7-X spec


# ---------------------------------------------------------------------------
# Load samples
# ---------------------------------------------------------------------------

def load_samples(dataset: str, data_dir: str, max_molecules=None):
    """Load the samples of ``dataset`` found under ``data_dir``.

    Args:
        dataset: Dataset key, either ``"qm9"`` or ``"qm7x"``.
        data_dir: Location handed unchanged to the dataset-specific loader.
        max_molecules: Forwarded unchanged to the loader as its
            ``max_molecules`` keyword (``None`` by default).

    Returns:
        For ``"qm9"``: a list obtained by indexing a ``QM9Dataset`` at
        every position from ``0`` to ``len(ds) - 1``.
        For ``"qm7x"``: whatever ``load_qm7x_equilibrium`` returns.

    Raises:
        ValueError: If ``dataset`` is neither ``"qm9"`` nor ``"qm7x"``.
    """
    # The dataset modules are imported lazily, only inside the branch that
    # actually needs them.
    if dataset == "qm9":
        from data.qm9 import QM9Dataset

        qm9 = QM9Dataset(data_dir, max_molecules=max_molecules)
        n_samples = len(qm9)
        return [qm9[idx] for idx in range(n_samples)]

    if dataset == "qm7x":
        from data.qm7x import load_qm7x_equilibrium as read_qm7x

        return read_qm7x(data_dir, max_molecules=max_molecules)

    raise ValueError(f"unknown dataset {dataset!r}")


# ---------------------------------------------------------------------------
# Build target array
# ---------------------------------------------------------------------------

def build_target(dataset: str, samples, target: str) -> np.ndarray:
    """Build the per-sample target values for ``target``.

    Args:
        dataset: Dataset key, either ``"qm9"`` or ``"qm7x"``.
        samples: Iterable of samples. On the direct QM9 path each sample
            must expose an indexable ``properties`` attribute.
        target: Name of the target to extract.

    Returns:
        For ``"qm9"`` with ``target`` present in ``PROPERTY_INDEX``: a
        float64 array holding ``properties[PROPERTY_INDEX[target]]`` of
        every sample.
        For ``"qm9"`` with any other ``target``: the result of
        ``train_starg._build_target(samples, target)`` converted to float64.
        For ``"qm7x"``: whatever ``qm7x_target_array`` returns.

    Raises:
        ValueError: If ``dataset`` is neither ``"qm9"`` nor ``"qm7x"``.
    """
    if dataset == "qm9":
        from data.qm9 import PROPERTY_INDEX

        if target not in PROPERTY_INDEX:
            # Not a stored property: delegate to the train_starg helper
            # that handles derived targets.
            from train_starg import _build_target

            derived = _build_target(samples, target)
            return np.asarray(derived, dtype=np.float64)

        values = [
            sample.properties[PROPERTY_INDEX[target]] for sample in samples
        ]
        return np.array(values, dtype=np.float64)

    if dataset == "qm7x":
        from data.qm7x import qm7x_target_array

        return qm7x_target_array(samples, target)

    raise ValueError(f"unknown dataset {dataset!r}")


# ---------------------------------------------------------------------------
# Split
# ---------------------------------------------------------------------------

def split_indices(
    dataset: str,
    n_total: int,
    seed: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return the (train, val, test) index split for ``dataset``.

    Args:
        dataset: Dataset key, either ``"qm9"`` or ``"qm7x"``.
        n_total: Number of samples, passed positionally to the splitter.
        seed: Passed to the splitter as its ``seed`` keyword.

    Returns:
        The result of ``qm9_split`` or ``qm7x_split`` respectively,
        returned unchanged.

    Raises:
        ValueError: If ``dataset`` is neither ``"qm9"`` nor ``"qm7x"``.
    """
    # Pick the dataset-specific splitter first (lazy import), then make a
    # single call with identical arguments for both datasets.
    if dataset == "qm9":
        from data.qm9 import qm9_split as make_split
    elif dataset == "qm7x":
        from data.qm7x import qm7x_split as make_split
    else:
        raise ValueError(f"unknown dataset {dataset!r}")

    return make_split(n_total, seed=seed)


# ---------------------------------------------------------------------------
# Element set
# ---------------------------------------------------------------------------

def element_set(dataset: str) -> List[int]:
    """Return the element list the model should be configured for.

    Args:
        dataset: Dataset key, either ``"qm9"`` or ``"qm7x"``.

    Returns:
        A new list copied from ``QM9_ELEMENTS`` or ``QM7X_ELEMENTS``, so
        callers may mutate it without touching the module-level constant.

    Raises:
        ValueError: If ``dataset`` is neither ``"qm9"`` nor ``"qm7x"``.
    """
    if dataset == "qm9":
        elements = QM9_ELEMENTS
    elif dataset == "qm7x":
        elements = QM7X_ELEMENTS
    else:
        raise ValueError(f"unknown dataset {dataset!r}")

    # Copy so the shared constant is never exposed to callers.
    return list(elements)