"""QM7-X equilibrium-conformation loader for the polarizability experiment.

QM7-X HDF5 schema: each file (1000.hdf5, 2000.hdf5, ...) contains
molecule-id groups at the top level; each molecule has one or more
conformation-id sub-groups. The equilibrium structure carries a "-opt"
suffix in the conformation id, which we use as the per-molecule selector.

This loader is deliberately independent of any external code (no imports
from sibling repositories, no reuse of pre-existing QM7-X experiment
scripts). It implements only what `python/large_scale/` needs:

* `load_qm7x_equilibrium(data_dir, max_molecules=None)` returns a list
  of `QM7XSample` dataclasses with `Z`, `coords`, `pol`.
* `decompose_polarizability(pol)` returns the three irrep-norm scalar
  targets used by the cross-selectivity comparison: `alpha_iso` (A_1g
  trace), `alpha_E` (E_g traceless-diagonal anisotropy), and `alpha_T2`
  (T_2g off-diagonal magnitude).
* `qm7x_target_array(samples, target)` builds the (N,) array of one of
  those scalar targets over a list of samples.
* `qm7x_split(n_total, seed, frac_train, frac_val)` returns reproducible
  train/validation/test index arrays.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
import h5py
import numpy as np
@dataclass
class QM7XSample:
    """One equilibrium conformation as read from a QM7-X HDF5 file.

    The loader in this module fills the fields from the `atNUM`, `atXYZ`
    and `mTPOL` datasets of a conformation group, respectively.
    """

    # Atomic numbers, one entry per atom: shape (n_atoms,).
    Z: np.ndarray
    # Cartesian coordinates, one row per atom: shape (n_atoms, 3).
    coords: np.ndarray
    # Molecular polarizability tensor (mTPOL) as a (3, 3) matrix.
    pol: np.ndarray
def _as_polarizability_matrix(pol_raw: np.ndarray, location: str) -> np.ndarray:
    """Return a raw mTPOL array as a (3, 3) matrix.

    Accepts either a flattened 9-vector (reshaped in row-major order) or an
    array that is already (3, 3). Any other shape raises ValueError, naming
    `location` so the offending record can be found.
    """
    if pol_raw.shape == (9,):
        return pol_raw.reshape(3, 3)
    if pol_raw.shape == (3, 3):
        return pol_raw
    raise ValueError(
        f"unexpected mTPOL shape {pol_raw.shape} "
        f"in {location}"
    )


def load_qm7x_equilibrium(
    data_dir: str,
    max_molecules: Optional[int] = None,
) -> List[QM7XSample]:
    """Load every equilibrium ("-opt") conformation from QM7-X HDF5 files.

    Files matching ``*.hdf5`` directly inside `data_dir` are visited in
    sorted path order. Within each file, every top-level group is treated
    as a molecule and every sub-group whose name ends in "-opt" is read.

    Args:
        data_dir : directory containing one or more *.hdf5 files.
        max_molecules : optional cap on samples returned (for smoke tests).
            ``None`` means no cap; ``0`` returns an empty list without
            opening any file.

    Returns:
        list of QM7XSample, one per equilibrium conformation found. The list
        is empty when the directory contains no matching files.

    Raises:
        ValueError: if `max_molecules` is negative, or if an `mTPOL` dataset
            has a shape other than (9,) or (3, 3).
    """
    samples: List[QM7XSample] = []

    # Compare against None explicitly: a plain truthiness test would treat a
    # cap of 0 as "no cap" and load the whole dataset.
    if max_molecules is not None:
        if max_molecules < 0:
            raise ValueError(
                f"max_molecules must be non-negative, got {max_molecules}"
            )
        if max_molecules == 0:
            return samples

    for hdf5_path in sorted(Path(data_dir).glob("*.hdf5")):
        with h5py.File(str(hdf5_path), "r") as f:
            for mol_id in f.keys():
                mol_group = f[mol_id]
                for conf_id in mol_group.keys():
                    # Only the equilibrium structure carries the "-opt" suffix.
                    if not conf_id.endswith("-opt"):
                        continue
                    g = mol_group[conf_id]
                    # np.asarray reads the dataset into memory, so the
                    # arrays remain usable after the file is closed.
                    pol_mat = _as_polarizability_matrix(
                        np.asarray(g["mTPOL"]),
                        f"{hdf5_path}::{mol_id}/{conf_id}",
                    )
                    samples.append(
                        QM7XSample(
                            Z=np.asarray(g["atNUM"]),
                            coords=np.asarray(g["atXYZ"]),
                            pol=pol_mat,
                        )
                    )
                    if (
                        max_molecules is not None
                        and len(samples) >= max_molecules
                    ):
                        return samples
    return samples
def decompose_polarizability(pol: np.ndarray) -> dict:
    """Reduce a 3×3 polarizability tensor to three irrep-magnitude scalars.

    The input is cast to float64 and split into an isotropic part plus a
    symmetrized traceless remainder, from which the following are computed:

    * `alpha_iso` : A_1g component, one third of the trace.
    * `alpha_E`   : E_g magnitude, the norm of the two diagonal-anisotropy
                    components (xx - yy)/sqrt(2) and
                    (2 zz - xx - yy)/sqrt(6).
    * `alpha_T2`  : T_2g magnitude, sqrt(2) times the norm of the three
                    off-diagonal components (xy, xz, yz).

    Returns:
        dict with keys "alpha_iso", "alpha_E" and "alpha_T2", each a Python
        float, so callers can pick whichever scalar target they need.
    """
    tensor = np.asarray(pol, dtype=np.float64)

    # A_1g: isotropic part.
    iso = float(np.trace(tensor) / 3.0)

    # Remove the isotropic part, then symmetrize so that an asymmetric input
    # contributes only through its symmetric component.
    deviator = tensor - iso * np.eye(3)
    s = (deviator + deviator.T) / 2.0
    xx, yy, zz = s[0, 0], s[1, 1], s[2, 2]
    xy, xz, yz = s[0, 1], s[0, 2], s[1, 2]

    # E_g: two diagonal-anisotropy components.
    e_a = (xx - yy) / np.sqrt(2.0)
    e_b = (2.0 * zz - xx - yy) / np.sqrt(6.0)
    e_norm = float(np.sqrt(e_a ** 2 + e_b ** 2))

    # T_2g: each off-diagonal entry appears twice in the symmetric tensor,
    # hence the sqrt(2) factor on the norm.
    offdiag = np.sqrt(xy ** 2 + xz ** 2 + yz ** 2)
    t2_norm = float(offdiag * np.sqrt(2.0))

    return {"alpha_iso": iso, "alpha_E": e_norm, "alpha_T2": t2_norm}
def qm7x_target_array(samples: List[QM7XSample], target: str) -> np.ndarray:
    """Collect one polarizability scalar target over `samples`.

    Args:
        samples : samples whose `pol` tensors are decomposed.
        target : one of "alpha_iso", "alpha_E" or "alpha_T2".

    Returns:
        float64 array of shape (N,), in the same order as `samples`.

    Raises:
        ValueError: if `target` is not one of the three known names.
    """
    known_targets = ("alpha_iso", "alpha_E", "alpha_T2")
    if target not in known_targets:
        raise ValueError(f"unknown qm7x target {target!r}; "
                         f"expected alpha_iso | alpha_E | alpha_T2")

    values = []
    for sample in samples:
        scalars = decompose_polarizability(sample.pol)
        values.append(scalars[target])
    return np.array(values, dtype=np.float64)
def qm7x_split(n_total: int, seed: int = 42, frac_train: float = 0.6,
               frac_val: float = 0.2):
    """Return reproducible train/validation/test index arrays.

    A permutation of ``range(n_total)`` is drawn from a NumPy generator
    seeded with `seed` and cut into three consecutive pieces. With the
    default fractions this is a 60/20/20 split.

    Args:
        n_total : number of items to split.
        seed : seed for the random generator.
        frac_train : fraction assigned to training (truncated to an int count).
        frac_val : fraction assigned to validation (truncated to an int count).

    Returns:
        tuple ``(train_idx, val_idx, test_idx)`` of index arrays; the test
        part receives whatever remains after the first two.
    """
    shuffled = np.random.default_rng(seed).permutation(n_total)

    # Counts are truncated, so any rounding remainder falls to the test part.
    train_end = int(frac_train * n_total)
    val_end = train_end + int(frac_val * n_total)

    train_idx = shuffled[:train_end]
    val_idx = shuffled[train_end:val_end]
    test_idx = shuffled[val_end:]
    return train_idx, val_idx, test_idx