"""QM7-X equilibrium-conformation loader for the polarizability experiment.

QM7-X HDF5 schema: each file (1000.hdf5, 2000.hdf5, ...) contains
molecule-id groups at the top level; each molecule has one or more
conformation-id sub-groups. The equilibrium structure carries a "-opt"
suffix in the conformation id, which we use as the per-molecule selector.

This loader is deliberately independent of any external code (no imports
from sibling repositories, no reuse of pre-existing QM7-X experiment
scripts). It implements only what `python/large_scale/` needs:

  * `load_qm7x_equilibrium(data_dir, max_molecules=None)` returns a list of
    `QM7XSample` dataclasses with `Z`, `coords`, `pol`.
  * `decompose_polarizability(pol)` returns the three irrep-norm scalar
    targets used by the cross-selectivity comparison: `alpha_iso` (A_1g
    trace), `alpha_E` (E_g traceless-diagonal anisotropy), and `alpha_T2`
    (T_2g off-diagonal magnitude).
  * `qm7x_target_array(samples, target)` stacks one of those scalars into an
    (N,) float64 array, and `qm7x_split(n_total, ...)` produces a
    reproducible train/val/test index split.

Only `load_qm7x_equilibrium` needs h5py; the other helpers are pure NumPy.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np

try:
    import h5py
except ImportError:
    # h5py is only required to read the HDF5 files; keep the pure-NumPy
    # helpers importable without it. load_qm7x_equilibrium checks for this.
    h5py = None


# Scalar targets produced by `decompose_polarizability`.
_POL_TARGETS = ("alpha_iso", "alpha_E", "alpha_T2")


@dataclass
class QM7XSample:
    """One QM7-X equilibrium conformation as read by `load_qm7x_equilibrium`."""

    Z: np.ndarray  # (n_atoms,) atomic numbers
    coords: np.ndarray  # (n_atoms, 3) Cartesian coordinates
    pol: np.ndarray  # (3, 3) molecular polarizability mTPOL


def _as_pol_matrix(pol_raw: np.ndarray, where: str) -> np.ndarray:
    """Return a raw mTPOL array as (3, 3); `where` locates it for errors."""
    # A flattened 9-vector is assumed to be row-major. reshape() does NOT
    # symmetrize; decompose_polarizability symmetrizes before splitting.
    if pol_raw.shape == (9,):
        return pol_raw.reshape(3, 3)
    if pol_raw.shape == (3, 3):
        return pol_raw
    raise ValueError(f"unexpected mTPOL shape {pol_raw.shape} in {where}")


def load_qm7x_equilibrium(
    data_dir: str,
    max_molecules: Optional[int] = None,
) -> List[QM7XSample]:
    """Stream every equilibrium ("-opt") conformation from QM7-X HDF5 files.

    Files matching ``*.hdf5`` directly inside `data_dir` are read in sorted
    path order; within a file, molecule groups and their conformation
    sub-groups are visited in the order ``keys()`` yields them.

    Args:
        data_dir      : directory containing one or more *.hdf5 files.
        max_molecules : optional cap on samples returned (for smoke tests).
                        ``None`` means no cap and ``0`` returns an empty list.
                        The cap counts "-opt" conformations, not molecule
                        groups.

    Returns:
        list of QM7XSample, one per equilibrium conformation found (at most
        `max_molecules` of them).

    Raises:
        ValueError  : if `max_molecules` is negative, or an mTPOL dataset has
                      a shape other than (9,) or (3, 3).
        ImportError : if h5py is not installed.
    """
    if max_molecules is not None:
        if max_molecules < 0:
            raise ValueError(
                f"max_molecules must be >= 0 or None, got {max_molecules}"
            )
        if max_molecules == 0:
            # Handled up front: the cap check below only runs after an append.
            return []
    if h5py is None:
        raise ImportError(
            "load_qm7x_equilibrium requires h5py to read QM7-X HDF5 files"
        )

    samples: List[QM7XSample] = []
    for hdf5_path in sorted(Path(data_dir).glob("*.hdf5")):
        with h5py.File(str(hdf5_path), "r") as f:
            for mol_id in f.keys():
                mol_group = f[mol_id]
                for conf_id in mol_group.keys():
                    if not conf_id.endswith("-opt"):
                        continue
                    g = mol_group[conf_id]
                    pol_mat = _as_pol_matrix(
                        np.asarray(g["mTPOL"]),
                        f"{hdf5_path}::{mol_id}/{conf_id}",
                    )
                    samples.append(
                        QM7XSample(
                            Z=np.asarray(g["atNUM"]),
                            coords=np.asarray(g["atXYZ"]),
                            pol=pol_mat,
                        )
                    )
                    if (
                        max_molecules is not None
                        and len(samples) >= max_molecules
                    ):
                        # Returning inside `with` still closes the file.
                        return samples
    return samples


def decompose_polarizability(pol: np.ndarray) -> dict:
    """Octahedral irrep decomposition of a 3×3 symmetric polarizability tensor.

    Splits a (3, 3) tensor into:
      * `alpha_iso` : A_1g component (1/3 trace)
      * `alpha_E`   : E_g magnitude (norm of the two diagonal-anisotropy
                      components in real-spherical normalization)
      * `alpha_T2`  : T_2g magnitude (norm of the three off-diagonal
                      components, factor sqrt(2) from real-spherical
                      normalization)

    With these normalizations ``alpha_E**2 + alpha_T2**2`` equals the squared
    Frobenius norm of the symmetrized traceless part ``pol - alpha_iso * I``.
    The tensor is symmetrized before the E_g / T_2g split, so any
    antisymmetric part of the input is discarded.

    Args:
        pol : array-like of shape (3, 3); converted to float64.

    Returns:
        dict with float values under "alpha_iso", "alpha_E" and "alpha_T2",
        so callers can pick the required scalar target.

    Raises:
        ValueError : if `pol` does not have shape (3, 3).
    """
    pol = np.asarray(pol, dtype=np.float64)
    if pol.shape != (3, 3):
        raise ValueError(
            f"expected a (3, 3) polarizability tensor, got shape {pol.shape}"
        )
    alpha_iso = float(np.trace(pol) / 3.0)
    traceless = pol - alpha_iso * np.eye(3)
    sym = (traceless + traceless.T) / 2.0
    # Orthonormal basis of the traceless diagonal (E_g) subspace.
    e1 = (sym[0, 0] - sym[1, 1]) / np.sqrt(2.0)
    e2 = (2.0 * sym[2, 2] - sym[0, 0] - sym[1, 1]) / np.sqrt(6.0)
    alpha_E = float(np.sqrt(e1 ** 2 + e2 ** 2))
    # Each off-diagonal entry appears twice in the symmetric tensor -> sqrt(2).
    alpha_T2 = float(
        np.sqrt(sym[0, 1] ** 2 + sym[0, 2] ** 2 + sym[1, 2] ** 2)
        * np.sqrt(2.0)
    )
    return {"alpha_iso": alpha_iso, "alpha_E": alpha_E, "alpha_T2": alpha_T2}


def qm7x_target_array(samples: List[QM7XSample], target: str) -> np.ndarray:
    """Build a (N,) float64 array of one polarizability scalar target.

    Args:
        samples : samples whose `pol` tensors are decomposed.
        target  : one of "alpha_iso", "alpha_E", "alpha_T2".

    Raises:
        ValueError : if `target` is not one of the names above.
    """
    if target not in _POL_TARGETS:
        raise ValueError(f"unknown qm7x target {target!r}; "
                         f"expected alpha_iso | alpha_E | alpha_T2")
    return np.array(
        [decompose_polarizability(s.pol)[target] for s in samples],
        dtype=np.float64,
    )


def qm7x_split(
    n_total: int,
    seed: int = 42,
    frac_train: float = 0.6,
    frac_val: float = 0.2,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Reproducible random train/val/test split of ``range(n_total)``.

    The defaults give the 60/20/20 split (per the QM7-X protocol). Split
    sizes are ``int(frac_train * n_total)`` and ``int(frac_val * n_total)``
    (truncated); every remaining index goes to the test set. The permutation
    is over sample indices, so it is a per-molecule split only when each
    molecule contributes exactly one sample.

    Args:
        n_total    : number of samples to split.
        seed       : seed for ``np.random.default_rng``.
        frac_train : fraction of samples assigned to training.
        frac_val   : fraction of samples assigned to validation.

    Returns:
        (train_idx, val_idx, test_idx) integer index arrays that are disjoint
        and together cover ``range(n_total)``.

    Raises:
        ValueError : if a fraction is negative or they sum to more than 1.
    """
    # Small tolerance so e.g. 0.7 + 0.3 is not rejected by rounding.
    if frac_train < 0 or frac_val < 0 or frac_train + frac_val > 1.0 + 1e-12:
        raise ValueError(
            "split fractions must be non-negative with "
            f"frac_train + frac_val <= 1, got frac_train={frac_train}, "
            f"frac_val={frac_val}"
        )
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_total)
    n_train = int(frac_train * n_total)
    n_val = int(frac_val * n_total)
    return (
        perm[:n_train],
        perm[n_train:n_train + n_val],
        perm[n_train + n_val:],
    )