# tensor-group-sym / python / large_scale / data / qm7x.py
"""QM7-X equilibrium-conformation loader for the polarizability experiment.

QM7-X HDF5 schema: each file (1000.hdf5, 2000.hdf5, ...) contains
molecule-id groups at the top level; each molecule has one or more
conformation-id sub-groups. Equilibrium structures carry a "-opt" suffix
in the conformation id, which is the selector used here: every matching
sub-group becomes one sample, so a molecule with several "-opt"
sub-groups contributes several samples.

This loader is deliberately independent of any external code (no imports
from sibling repositories, no reuse of pre-existing QM7-X experiment
scripts). It implements only what `python/large_scale/` needs:

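  * `load_qm7x_equilibrium(data_dir, max_molecules=None)` returns a list
    of `QM7XSample` dataclasses with `Z`, `coords`, `pol`.
  * `decompose_polarizability(pol)` returns the three irrep-norm scalar
    targets used by the cross-selectivity comparison: `alpha_iso` (A_1g,
    one third of the trace), `alpha_E` (E_g traceless-diagonal
    anisotropy), and `alpha_T2` (T_2g off-diagonal magnitude).
  * `qm7x_target_array(samples, target)` stacks one of those scalars for
    every sample into an (N,) float64 array.
  * `qm7x_split(n_total, seed=42, ...)` returns reproducible
    train/val/test index arrays (60/20/20 by default).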
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import h5py
import numpy as np


@dataclass
class QM7XSample:
    """One QM7-X equilibrium conformation, as produced by the loader.

    Arrays are kept exactly as read from the HDF5 group; no unit
    conversion or symmetrization is applied.
    """

    # Atomic numbers from the "atNUM" dataset; shape (n_atoms,).
    Z: np.ndarray
    # Cartesian coordinates from the "atXYZ" dataset; shape (n_atoms, 3).
    coords: np.ndarray
    # Molecular polarizability tensor from "mTPOL", brought to shape (3, 3).
    pol: np.ndarray


def load_qm7x_equilibrium(
    data_dir: str,
    max_molecules: Optional[int] = None,
) -> List[QM7XSample]:
    """Stream every equilibrium ("-opt") conformation from QM7-X HDF5 files.

    Every ``*.hdf5`` file directly inside ``data_dir`` is opened read-only,
    in sorted path order. Within a file, each top-level molecule group is
    scanned and every sub-group whose name ends in ``"-opt"`` yields one
    sample built from its ``atNUM``, ``atXYZ`` and ``mTPOL`` datasets.

    Args:
        data_dir : directory containing one or more *.hdf5 files.
        max_molecules : optional cap on samples returned (for smoke tests).
            ``None`` means no cap; ``0`` or a negative value returns an
            empty list without opening any file.

    Returns:
        list of QM7XSample, one per equilibrium conformation found.

    Raises:
        FileNotFoundError: if ``data_dir`` is not an existing directory.
        ValueError: if an ``mTPOL`` dataset has a shape other than (9,)
            or (3, 3).
    """
    data_dir = Path(data_dir)
    if not data_dir.is_dir():
        # Without this check a mistyped path globs to nothing and the
        # caller silently receives an empty dataset.
        raise FileNotFoundError(f"QM7-X data directory not found: {data_dir}")

    samples: List[QM7XSample] = []
    # Compare against None explicitly: a plain truthiness test would treat
    # max_molecules=0 as "no cap" and load the whole dataset.
    if max_molecules is not None and max_molecules <= 0:
        return samples

    for hdf5_path in sorted(data_dir.glob("*.hdf5")):
        with h5py.File(str(hdf5_path), "r") as f:
            for mol_id in f.keys():
                mol_group = f[mol_id]  # look the group up once per molecule
                for conf_id in mol_group.keys():
                    if not conf_id.endswith("-opt"):
                        continue
                    g = mol_group[conf_id]
                    pol_raw = np.asarray(g["mTPOL"])
                    # mTPOL may be stored flat (9,) or already as (3, 3).
                    # The flat form is reshaped in C (row-major) order; no
                    # symmetrization is applied here.
                    if pol_raw.shape == (9,):
                        pol_mat = pol_raw.reshape(3, 3)
                    elif pol_raw.shape == (3, 3):
                        pol_mat = pol_raw
                    else:
                        raise ValueError(
                            f"unexpected mTPOL shape {pol_raw.shape} "
                            f"in {hdf5_path}::{mol_id}/{conf_id}"
                        )
                    samples.append(
                        QM7XSample(
                            Z=np.asarray(g["atNUM"]),
                            coords=np.asarray(g["atXYZ"]),
                            pol=pol_mat,
                        )
                    )
                    if (max_molecules is not None
                            and len(samples) >= max_molecules):
                        # Returning inside `with` still closes the file.
                        return samples
    return samples


def decompose_polarizability(pol: np.ndarray) -> dict:
    """Split a 3x3 polarizability tensor into octahedral irrep scalars.

    A symmetric rank-2 tensor decomposes under O_h as A_1g + E_g + T_2g.
    Returned keys:

      * ``alpha_iso`` : A_1g part, one third of the trace.
      * ``alpha_E``   : E_g magnitude, the norm of the traceless diagonal
                        written in the orthonormal basis (1, -1, 0)/sqrt(2),
                        (-1, -1, 2)/sqrt(6).
      * ``alpha_T2``  : T_2g magnitude, sqrt(2) times the norm of the three
                        upper off-diagonal entries (the Frobenius norm of
                        the off-diagonal part).

    The input is symmetrized before the anisotropic parts are measured, so
    any antisymmetric component is discarded. alpha_E**2 + alpha_T2**2
    equals the squared Frobenius norm of the symmetric traceless part.

    Returns a dict so callers can pick the required scalar target.
    """
    tensor = np.asarray(pol, dtype=np.float64)
    trace_third = float(np.trace(tensor) / 3.0)

    # Remove the isotropic part, then keep only the symmetric piece.
    deviator = tensor - trace_third * np.eye(3)
    deviator = (deviator + deviator.T) / 2.0

    xx, yy, zz = deviator[0, 0], deviator[1, 1], deviator[2, 2]
    xy, xz, yz = deviator[0, 1], deviator[0, 2], deviator[1, 2]

    # Projections of the traceless diagonal onto the two E_g basis vectors.
    eg_a = (xx - yy) / np.sqrt(2.0)
    eg_b = (2.0 * zz - xx - yy) / np.sqrt(6.0)
    off_diag_norm = np.sqrt(xy ** 2 + xz ** 2 + yz ** 2)

    return {
        "alpha_iso": trace_third,
        "alpha_E": float(np.sqrt(eg_a ** 2 + eg_b ** 2)),
        "alpha_T2": float(off_diag_norm * np.sqrt(2.0)),
    }


def qm7x_target_array(samples: List[QM7XSample], target: str) -> np.ndarray:
    """Collect one polarizability scalar target for every sample.

    Args:
        samples : QM7-X samples; each ``pol`` tensor is passed through
            ``decompose_polarizability``.
        target : one of "alpha_iso", "alpha_E", "alpha_T2".

    Returns:
        float64 array of shape (len(samples),).

    Raises:
        ValueError: if ``target`` is not one of the three known keys.
    """
    known_targets = ("alpha_iso", "alpha_E", "alpha_T2")
    if target in known_targets:
        per_sample = (
            decompose_polarizability(sample.pol)[target] for sample in samples
        )
        return np.fromiter(per_sample, dtype=np.float64)
    raise ValueError(
        f"unknown qm7x target {target!r}; "
        "expected alpha_iso | alpha_E | alpha_T2"
    )


def qm7x_split(
    n_total: int,
    seed: int = 42,
    frac_train: float = 0.6,
    frac_val: float = 0.2,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Reproducible random train/val/test split of sample indices.

    Defaults give a 60/20/20 split. Indices ``0 .. n_total - 1`` are
    permuted with ``np.random.default_rng(seed)``. The first
    ``int(frac_train * n_total)`` go to train, the next
    ``int(frac_val * n_total)`` go to validation, and the remainder, which
    absorbs the truncation, goes to test.

    The split is over sample indices, not molecule ids. If one molecule
    contributes several samples, they may land in different partitions.

    Args:
        n_total : number of samples to split (>= 0).
        seed : RNG seed; the same seed gives the same split.
        frac_train : fraction of samples for training (>= 0).
        frac_val : fraction of samples for validation (>= 0).

    Returns:
        (train_idx, val_idx, test_idx): disjoint integer index arrays whose
        union is ``range(n_total)``.

    Raises:
        ValueError: if ``n_total`` is negative, either fraction is negative
            or NaN, or ``frac_train + frac_val`` exceeds 1. A negative
            ``frac_val`` would otherwise make the test slice overlap the
            training slice.
    """
    if n_total < 0:
        raise ValueError(f"n_total must be non-negative, got {n_total}")
    # Written as `not (x >= 0)` so that NaN fractions are rejected too.
    if not (frac_train >= 0.0 and frac_val >= 0.0):
        raise ValueError(
            f"split fractions must be non-negative, got "
            f"frac_train={frac_train}, frac_val={frac_val}"
        )
    # Small tolerance so that sums like 0.7 + 0.3 are not rejected for
    # floating-point rounding alone.
    if frac_train + frac_val > 1.0 + 1e-9:
        raise ValueError(
            f"frac_train + frac_val must not exceed 1, got "
            f"{frac_train} + {frac_val}"
        )
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_total)
    n_train = int(frac_train * n_total)
    n_val = int(frac_val * n_total)
    return (
        perm[:n_train],
        perm[n_train:n_train + n_val],
        perm[n_train + n_val:],
    )