# tensor-group-sym/python/large_scale/data/featurizers.py
"""Molecule → tensor featurization for the ★_G pipeline.

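Featurizers produce a `(n_feat, |G|)` molecule-level summary; see
CONTRIBUTING.md for the input-contract discipline. Adding more
features means appending rows; it never means adding a per-atom
dimension. (raw_atomic_features, below, is the deliberate exception:
it feeds the per-atom baselines, never ★_G.)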

  * cyclic_angular_features: 14 rows of angular projections under z-axis
    rotations at n equally-spaced angles. Used with G = Z_n.

  * octahedral_features: 14 rows of features under the 24 rotations of
    the chiral octahedral group O. Used in the Wigner-Eckart experiment.

  * coulomb_eig_extended_features: 14 angular rows + 29 Coulomb-matrix
    sorted eigenvalues replicated as invariant rows = (43, |G|) tensor.
    The Coulomb-matrix eigenvalues are a classical Rupp-2012 descriptor
    of inter-atomic distances and atomic numbers; they encode bond-
    topology information that pure angular projections lack, while
    remaining a molecule-level summary (no atom dim, no learnable
    parameters). Illustrative: it shows that the same algebra scales
    with input richness.

  * raw_atomic_features: per-atom Z, position, and Mulliken charge,
    consumed only by SchNet / e3nn / MACE baselines (not by ★_G).
"""

from __future__ import annotations

from typing import List

import numpy as np
import torch

from .qm9 import QM9Sample
# Absolute import: this resolves when scripts are run from the large_scale/
# directory (which puts it on sys.path), and avoids the "attempted relative
# import beyond top-level package" error that a relative import of
# starg_torch would raise when this module is loaded as part of the data
# subpackage.
from starg_torch.octahedral import octahedral_rotations


def cyclic_angular_features(
    sample: QM9Sample,
    n_rot: int,
    n_feat: int = 14,
) -> np.ndarray:
    """Compute an (n_feat, n_rot) tensor of angular features under Z_n.

    Identical to the MATLAB angular_features() used in the QM9 experiment:
    the atomic positions, centred on their geometric mean, are rotated about
    the z-axis by the angles 2π g / n_rot (g = 0 .. n_rot-1) and projected
    onto a fixed measurement basis. Row layout:

      * rows 0-4, invariant under z-rotations: mean Z, mean and std of the
        distance from the centre, mean z^2, mean |z|;
      * rows 5-13, equivariant, three rows per scale s in (1, 2, 3):
        sum_i Z_i cos(s φ_i), sum_i Z_i sin(s φ_i), sum_i Z_i x_i^s,
        where φ_i and x_i are the azimuth and x-coordinate of atom i after
        rotation.

    Rows beyond the 14th (when n_feat > 14) are left at zero.

    Args:
        sample: molecule providing ``coords`` and ``Z``.
        n_rot: order n of the cyclic group (number of rotation angles).
        n_feat: number of output rows; must be at least 14.

    Returns:
        Array of shape (n_feat, n_rot).

    Raises:
        ValueError: if ``n_feat < 14`` (too few rows for the layout above).
    """
    if n_feat < 14:
        raise ValueError(f"n_feat must be >= 14, got {n_feat}")
    pos = sample.coords - sample.coords.mean(axis=0, keepdims=True)
    Z = sample.Z
    out = np.zeros((n_feat, n_rot))

    # A rotation about z preserves |r| and the z-coordinate, so rows 0-4 are
    # the same for every group element: compute them once and broadcast,
    # which also makes the columns exactly (not just approximately) equal.
    radii = np.linalg.norm(pos, axis=1)
    out[0, :] = np.mean(Z)
    out[1, :] = np.mean(radii)
    out[2, :] = np.std(radii)
    out[3, :] = np.mean(pos[:, 2] ** 2)
    out[4, :] = np.mean(np.abs(pos[:, 2]))

    angles = 2.0 * np.pi * np.arange(n_rot) / n_rot
    for g, theta in enumerate(angles):
        c, s = np.cos(theta), np.sin(theta)
        R = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
        rotated = pos @ R.T
        # Azimuth of each rotated atom; shared by all three scales.
        phi = np.arctan2(rotated[:, 1], rotated[:, 0])
        for scale in range(1, 4):
            base = 5 + (scale - 1) * 3
            out[base, g] = np.sum(Z * np.cos(scale * phi))
            out[base + 1, g] = np.sum(Z * np.sin(scale * phi))
            out[base + 2, g] = np.sum(Z * rotated[:, 0] ** scale)
    return out  # shape (n_feat, n_rot)


def octahedral_features(
    sample: QM9Sample,
    n_feat: int = 14,
    rotations=None,
) -> np.ndarray:
    """Compute an (n_feat, |G|) tensor of features under a finite rotation group.

    By default G is the chiral octahedral group O, taken from
    octahedral_rotations(). The group dimension is the number of matrices
    supplied rather than a hard-coded 24. Positions are centred on their
    geometric mean and rotated by each R in G. Row layout:

      * rows 0-2, invariant: mean Z, mean and std of the distance from the
        centre;
      * rows 3-5: Z-weighted position sum (transforms as l=1);
      * rows 6-8: Mulliken-charge dipole sum_i q_i r_i (l=1);
      * rows 9-13: Z-weighted quadrupole-like components
        xx-yy, 2zz-xx-yy, xy, xz, yz (rank 2).

    Rows beyond the 14th (when n_feat > 14) are left at zero.

    Args:
        sample: molecule providing ``coords``, ``Z`` and ``charges``.
        n_feat: number of output rows; must be at least 14.
        rotations: optional iterable of 3x3 rotation matrices to use instead
            of octahedral_rotations() (e.g. a subgroup or a test fixture).
            They are assumed orthogonal: rows 1-2 rely on |R r| = |r|.

    Returns:
        Array of shape (n_feat, |G|).

    Raises:
        ValueError: if ``n_feat < 14``.
    """
    if n_feat < 14:
        raise ValueError(f"n_feat must be >= 14, got {n_feat}")
    if rotations is None:
        rotations = octahedral_rotations()
    # Materialise once so lists, (|G|, 3, 3) arrays and generators all work
    # and the column count comes from the data itself.
    mats = [np.asarray(R, dtype=float) for R in rotations]
    pos = sample.coords - sample.coords.mean(axis=0, keepdims=True)
    Z = sample.Z
    q = sample.charges
    out = np.zeros((n_feat, len(mats)))

    # Invariant rows are identical for every rotation: compute them once.
    radii = np.linalg.norm(pos, axis=1)
    out[0, :] = np.mean(Z)
    out[1, :] = np.mean(radii)
    out[2, :] = np.std(radii)

    for g, R in enumerate(mats):
        rp = pos @ R.T
        x, y, z = rp[:, 0], rp[:, 1], rp[:, 2]
        # Vector projections (transform as l=1)
        out[3, g] = np.sum(Z * x)
        out[4, g] = np.sum(Z * y)
        out[5, g] = np.sum(Z * z)
        # Mulliken-charge dipole (true rank-1 tensor)
        out[6, g] = np.sum(q * x)
        out[7, g] = np.sum(q * y)
        out[8, g] = np.sum(q * z)
        # Quadrupole-like rank-2 features (xx-yy, 2zz-xx-yy, xy, xz, yz)
        out[9, g] = np.sum(Z * (x ** 2 - y ** 2))
        out[10, g] = np.sum(Z * (2 * z ** 2 - x ** 2 - y ** 2))
        out[11, g] = np.sum(Z * x * y)
        out[12, g] = np.sum(Z * x * z)
        out[13, g] = np.sum(Z * y * z)
    return out


def _coulomb_matrix(coords: np.ndarray, Z: np.ndarray) -> np.ndarray:
    """Coulomb matrix (Rupp et al. 2012).

    M[i, j] = Z_i Z_j / |r_i - r_j| for i != j, and M[i, i] = 0.5 * Z_i^2.4.

    Pairwise distances are clamped below at 1e-8, so coincident atoms give a
    large finite entry instead of a division by zero.

    Args:
        coords: atomic positions, shape (n, 3).
        Z: atomic numbers, length n.

    Returns:
        Symmetric float64 array of shape (n, n).
    """
    coords = np.asarray(coords, dtype=float)
    z = np.asarray(Z, dtype=float)
    # Vectorized pairwise distances replace the per-pair Python loop.
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    M = np.outer(z, z) / np.maximum(dist, 1e-8)
    # The off-diagonal formula is meaningless at i == j; overwrite it.
    np.fill_diagonal(M, 0.5 * z ** 2.4)
    return M


def coulomb_eig_extended_features(
    sample: QM9Sample,
    n_rot: int = 12,
    n_eig: int = 29,
) -> np.ndarray:
    """Stack 14 angular rows on top of `n_eig` invariant Coulomb-spectrum rows.

    The invariant part is the spectrum of the Coulomb matrix (a classical
    molecule-level descriptor built from inter-atomic distances and atomic
    numbers). It is sorted largest-first, zero-padded or truncated to exactly
    `n_eig` entries (29 = max QM9 atom count), and repeated across all
    `n_rot` group columns. Every group element therefore sees the same value,
    which is the right semantic for a property of the molecule itself.

    Discipline: the result is still `(n_feat, n_rot)` per molecule, with
    n_feat = 14 + n_eig, no atom dimension and no learnable parameters. See
    CONTRIBUTING.md.
    """
    angular_block = cyclic_angular_features(sample, n_rot=n_rot, n_feat=14)

    coulomb = _coulomb_matrix(sample.coords, sample.Z)
    spectrum = np.flip(np.sort(np.linalg.eigvalsh(coulomb)))
    # Zero-pad small molecules, then cut to exactly n_eig entries.
    shortfall = max(0, n_eig - spectrum.shape[0])
    spectrum = np.concatenate([spectrum, np.zeros(shortfall)])[:n_eig]

    invariant_block = np.repeat(spectrum[:, np.newaxis], n_rot, axis=1)
    return np.concatenate([angular_block, invariant_block], axis=0)


def raw_atomic_features(sample: QM9Sample) -> dict:
    """Package per-atom inputs for the equivariant-NN baselines.

    Returns a dict with three tensors copied from `sample`:
    "z" (int64 atomic numbers), "pos" (float32 coordinates) and
    "charges" (float32 Mulliken charges).
    """
    real = torch.float32
    return dict(
        z=torch.tensor(sample.Z, dtype=torch.long),
        pos=torch.tensor(sample.coords, dtype=real),
        charges=torch.tensor(sample.charges, dtype=real),
    )


def stack_batch(samples: List[QM9Sample], featurizer, **kwargs) -> torch.Tensor:
    """Featurize every sample and batch the results.

    `featurizer(sample, **kwargs)` is called once per sample. The resulting
    arrays (which must share one shape) are stacked along a new leading axis
    and returned as a float32 tensor of shape (N, n_feat, n_g).
    """
    featurized = [featurizer(sample, **kwargs) for sample in samples]
    batch = np.stack(featurized)  # new leading (batch) axis
    return torch.tensor(batch, dtype=torch.float32)