# Repository path: tensor-group-sym/python/large_scale/data/matlab_angular_features.py
"""Faithful port of MATLAB `angular_features` from QM9_experiment.m.

The previous Python featurizer (`cyclic_angular_features` in
`featurizers.py`) was a from-scratch reinvention with a different and
much smaller feature set than the MATLAB reference. This module is a
line-by-line port of MATLAB `QM9_experiment.angular_features`
(experiments/QM9_experiment.m, lines 148-200) intended to pass an
element-wise numerical-equivalence test against the MATLAB output.

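The MATLAB function builds a feature matrix of shape (n_feat_target,
n_rot), where n_rot is the number of rotation samples and n_feat_target
defaults to 14. The underlying construction yields 48 raw rows for
molecules with at least four atoms (fewer for smaller molecules). These
rows are then truncated (or zero-padded) to n_feat_target. With the
default of 14, only the 9 charge-weighted moments and the first 5
Z-summed moments survive.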

The construction is:
    1. Center coordinates by mean.
    2. Define rotating bases e1=[cosθ,sinθ,0], e2=[-sinθ,cosθ,0], ez=[0,0,1].
    3. Project atoms onto each basis: p1, p2, pz of shape (n_atoms, n_rot).
    4. Compute weighted moments (charge-weighted w'p, w'p^2, w'p1*p2, etc.).
    5. Compute Z-summed moments (Zn'p, Zn'p^2, ...).
    6. Pull out top-4 heaviest atoms' coordinates p1[i], p2[i], pz[i].
    7. Pull out top-3 heaviest atom-pair Coulomb-like couplings.
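    8. Append rotation-invariant features (each replicated across the
       n_rot axis): pairwise-distance moments (mean, std, min, max), the
       four largest atom-to-centroid distances in descending order
       (zero-padded when there are fewer than four atoms), sum(Z^2)/100,
       mean(Z), and the atom count.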
    9. Stack as rows; pad with zeros if shorter than n_feat_target;
       truncate if longer.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from .qm9 import QM9Sample


def matlab_angular_features(
    sample: "QM9Sample",
    n_rot: int = 12,
    n_feat_target: int = 14,
    *,
    ddof: int = 0,
) -> np.ndarray:
    """Port of MATLAB `angular_features` (population-std variant by default).

    Args:
        sample : a QM9Sample-like object exposing `coords` (n_atoms, 3)
                 and `Z` (n_atoms,).
        n_rot  : number of rotation samples (MATLAB calls this n_rotations).
        n_feat_target : number of feature rows (MATLAB calls this
                        `n_feat_per_rot`; defaults to 14).
        ddof   : delta degrees of freedom for the pairwise-distance std
                 row; must be 0 or 1. The default 0 (population std) keeps
                 this function's historical output. MATLAB's `std` uses
                 1, so pass ``ddof=1`` (or call
                 `matlab_angular_features_strict`) for an exact MATLAB
                 match. With the default n_feat_target=14 the distance rows
                 are truncated away, so ``ddof`` only matters for larger
                 targets.

    Returns:
        np.ndarray of shape (n_feat_target, n_rot), dtype float64.

    Raises:
        ValueError: if ``ddof`` is not 0 or 1.
    """
    return _angular_features_impl(sample, n_rot, n_feat_target, ddof=ddof)


def matlab_angular_features_strict(
    sample: "QM9Sample",
    n_rot: int = 12,
    n_feat_target: int = 14,
) -> np.ndarray:
    """Strict variant matching MATLAB's std() default (sample std, ddof=1).

    MATLAB's `std(D)` on the pdist vector normalizes by N-1. For a single
    observation (a two-atom molecule) it normalizes by N and returns 0, and
    this variant reproduces that instead of NumPy's NaN.

    Args / Returns: as for `matlab_angular_features`.
    """
    return _angular_features_impl(sample, n_rot, n_feat_target, ddof=1)


def _angular_features_impl(
    sample: "QM9Sample",
    n_rot: int,
    n_feat_target: int,
    ddof: int,
) -> np.ndarray:
    """Build the MATLAB `angular_features` matrix; shared by both public variants.

    Row layout before truncation/padding (n = number of atoms):
        0-8    charge-weighted moments  w'p1, w'p2, w'pz, w'p1^2, w'p2^2,
               w'pz^2, w'(p1.*p2), w'(p1.*pz), w'(p2.*pz)
        9-14   Z-summed moments Zn'p1, Zn'p2, Zn'p1^2, Zn'p2^2,
               Zn'(p1.*p2), Zn'(p1.*pz)
        15-18  cubic moments w'p1^3, w'p2^3, Zn'p1^3, Zn'p2^3
        then   p1/p2/pz of the min(4, n) heaviest atoms (3 rows each)
        then   2 coupling rows for each of min(3, n-1) heaviest-atom pairs
        then   4 pairwise-distance stats, 4 radial distances, and
               sum(Z^2)/100, mean(Z), n
    Rows for heavy atoms and pairs are omitted (not zero-padded) for small
    molecules, which shifts later rows upward, as in the MATLAB code.
    """
    if ddof not in (0, 1):
        raise ValueError(f"ddof must be 0 or 1, got {ddof!r}")

    coords = np.asarray(sample.coords, dtype=np.float64)
    Z = np.asarray(sample.Z, dtype=np.float64)
    n_atoms = coords.shape[0]
    angles = 2.0 * np.pi * np.arange(n_rot) / n_rot

    Zn = Z.reshape(-1, 1)                                   # (n_atoms, 1)
    w = Zn / (np.sum(Zn) + 1e-10)                           # normalized charges
    coords_c = coords - coords.mean(axis=0, keepdims=True)  # unweighted centering

    # MATLAB: e1 = [cos; sin; 0], e2 = [-sin; cos; 0], each (3, n_rot).
    e1 = np.stack([np.cos(angles), np.sin(angles), np.zeros(n_rot)], axis=0)
    e2 = np.stack([-np.sin(angles), np.cos(angles), np.zeros(n_rot)], axis=0)
    p1 = coords_c @ e1                                      # (n_atoms, n_rot)
    p2 = coords_c @ e2                                      # (n_atoms, n_rot)
    # The z basis does not rotate, so every column is the centered z-coordinate.
    pz = np.repeat(coords_c[:, 2:3], n_rot, axis=1)         # (n_atoms, n_rot)

    def _flat(arr: np.ndarray) -> np.ndarray:
        """Collapse a (1, n_rot) product to an (n_rot,) row."""
        return np.asarray(arr).reshape(-1)

    # --- Moments (MATLAB lines 161-168) ---
    rows = [
        _flat(w.T @ p1), _flat(w.T @ p2), _flat(w.T @ pz),
        _flat(w.T @ (p1 ** 2)), _flat(w.T @ (p2 ** 2)), _flat(w.T @ (pz ** 2)),
        _flat(w.T @ (p1 * p2)), _flat(w.T @ (p1 * pz)), _flat(w.T @ (p2 * pz)),
        _flat(Zn.T @ p1), _flat(Zn.T @ p2),
        _flat(Zn.T @ (p1 ** 2)), _flat(Zn.T @ (p2 ** 2)),
        _flat(Zn.T @ (p1 * p2)), _flat(Zn.T @ (p1 * pz)),
        _flat(w.T @ (p1 ** 3)), _flat(w.T @ (p2 ** 3)),
        _flat(Zn.T @ (p1 ** 3)), _flat(Zn.T @ (p2 ** 3)),
    ]

    # --- Top-4 heaviest-atom rows (MATLAB lines 170-174) ---
    # MATLAB's `sort(Z, 'descend')` is stable: equal charges keep their
    # original order. A stable ascending argsort of -Z reproduces that.
    si = np.argsort(-Z, kind="stable")
    for idx in si[: min(4, n_atoms)]:
        rows.extend((p1[idx, :], p2[idx, :], pz[idx, :]))

    # --- Top-3 atom-pair couplings (MATLAB lines 175-180) ---
    # MATLAB: ii = si(1); jj = si(min(pp+1, n_atoms)). Because
    # pp <= n_atoms - 1, the clamp never binds and jj is si[pp] (0-based).
    for pp in range(1, min(3, n_atoms - 1) + 1):
        ii, jj = si[0], si[pp]
        d_ij = float(np.linalg.norm(coords_c[ii] - coords_c[jj])) + 1e-8
        coup = Z[ii] * Z[jj] / d_ij
        rows.append(coup * (p1[ii, :] - p1[jj, :]))
        rows.append(coup * (p2[ii, :] - p2[jj, :]))

    # --- Pairwise-distance moments (MATLAB lines 182-188) ---
    if n_atoms >= 2:
        from scipy.spatial.distance import pdist

        D = pdist(coords_c)  # condensed upper-triangular distances
        # MATLAB normalizes a single observation by N (std == 0). NumPy with
        # ddof=1 would divide by zero and return NaN.
        std_ddof = ddof if D.size > 1 else 0
        stats = (np.mean(D), np.std(D, ddof=std_ddof), np.min(D), np.max(D))
        rows.extend(np.full(n_rot, float(v)) for v in stats)
    else:
        rows.extend(np.zeros(n_rot) for _ in range(4))

    # --- Four largest atom-to-centroid distances (MATLAB lines 189-191) ---
    radial_desc = np.sort(np.sqrt(np.sum(coords_c ** 2, axis=1)))[::-1]
    n_rad = min(4, n_atoms)
    rows.extend(np.full(n_rot, float(radial_desc[k])) for k in range(n_rad))
    rows.extend(np.zeros(n_rot) for _ in range(n_rad, 4))

    # --- Global invariants (MATLAB lines 192-194) ---
    rows.append(np.full(n_rot, float(np.sum(Z ** 2)) / 100.0))
    rows.append(np.full(n_rot, float(np.mean(Z))))
    rows.append(np.full(n_rot, float(n_atoms)))

    # --- Stack, then zero-pad or truncate to n_feat_target (MATLAB 196-199) ---
    F = np.stack(rows, axis=0)
    nr = F.shape[0]
    if nr < n_feat_target:
        F = np.concatenate([F, np.zeros((n_feat_target - nr, n_rot))], axis=0)
    elif nr > n_feat_target:
        F = F[:n_feat_target, :]
    return F