"""Faithful port of MATLAB `angular_features` from QM9_experiment.m.

The previous Python featurizer (`cyclic_angular_features` in
`featurizers.py`) was a from-scratch reinvention with a different and much
smaller feature set than the MATLAB reference. This module is a line-by-line
port of MATLAB `QM9_experiment.angular_features`
(experiments/QM9_experiment.m, lines 148-200) intended to pass an
element-wise numerical-equivalence test against the MATLAB output.

The MATLAB function builds a feature matrix of shape (n_feat_target, n_rot)
where n_rot is the number of rotation samples and n_feat_target defaults to
14 (the underlying construction yields ~50 raw rows that are truncated to
n_feat_target). The construction is:

1. Center coordinates by mean.
2. Define rotating bases e1=[cosθ,sinθ,0], e2=[-sinθ,cosθ,0], ez=[0,0,1].
3. Project atoms onto each basis: p1, p2, pz of shape (n_atoms, n_rot).
4. Compute weighted moments (charge-weighted w'p, w'p^2, w'p1*p2, etc.).
5. Compute Z-summed moments (Zn'p, Zn'p^2, ...).
6. Pull out top-4 heaviest atoms' coordinates p1[i], p2[i], pz[i].
7. Pull out top-3 heaviest atom-pair Coulomb-like couplings.
8. Append invariant features (replicated across the n_rot axis): distance
   distribution moments (mean, std, min, max), top-4 sorted atomic radii,
   sum(Z^2)/100, mean(Z), and atom count.
9. Stack as rows; pad with zeros if shorter than n_feat_target; truncate
   if longer.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from .qm9 import QM9Sample


# Note on MATLAB `std`: MATLAB's `std(X)` normalizes by N-1 (ddof=1) by
# default, and the MATLAB feature applies `std(D)` to the pdist vector.
# `matlab_angular_features` keeps the historical population std (ddof=0);
# `matlab_angular_features_strict` uses ddof=1 and is the variant an
# element-wise equivalence test against MATLAB should call. The two only
# differ in the std row, which lies beyond the default n_feat_target=14.


def _angular_feature_matrix(
    sample: "QM9Sample",
    n_rot: int,
    n_feat_target: int,
    *,
    std_ddof: int,
) -> np.ndarray:
    """Build the (n_feat_target, n_rot) feature matrix shared by both variants.

    Args:
        sample: object exposing `coords` (n_atoms, 3) and `Z` (n_atoms,).
        n_rot: number of rotation samples.
        n_feat_target: number of rows in the returned matrix.
        std_ddof: delta degrees of freedom for the pairwise-distance std row
            (0 = population std, 1 = MATLAB's default sample std).

    Returns:
        np.ndarray of shape (n_feat_target, n_rot).
    """
    coords = np.asarray(sample.coords, dtype=np.float64)
    Z = np.asarray(sample.Z, dtype=np.float64)
    n_atoms = coords.shape[0]

    angles = 2.0 * np.pi * np.arange(n_rot) / n_rot
    Zn = Z.reshape(-1, 1)  # (n_atoms, 1)
    w = Zn / (np.sum(Zn) + 1e-10)  # (n_atoms, 1); epsilon guards sum(Z) == 0
    coords_c = coords - coords.mean(axis=0, keepdims=True)  # (n_atoms, 3)

    # MATLAB: e1 = [cos(angles); sin(angles); zeros(1,n_rot)]   -> (3, n_rot)
    #         e2 = [-sin(angles); cos(angles); zeros(1,n_rot)]  -> (3, n_rot)
    e1 = np.stack([np.cos(angles), np.sin(angles), np.zeros(n_rot)], axis=0)
    e2 = np.stack([-np.sin(angles), np.cos(angles), np.zeros(n_rot)], axis=0)

    # MATLAB: p1 = coords_c * e1 -> (n_atoms, 3) by (3, n_rot)
    p1 = coords_c @ e1  # (n_atoms, n_rot)
    p2 = coords_c @ e2  # (n_atoms, n_rot)
    # MATLAB: pz = coords_c * repmat([0;0;1], 1, n_rot), i.e. the z column
    # of coords_c repeated for every rotation.
    pz = np.repeat(coords_c[:, 2:3], n_rot, axis=1)  # (n_atoms, n_rot)

    def _row(arr: np.ndarray) -> np.ndarray:
        """Flatten a (1, n_rot) or (n_rot,) array to shape (n_rot,)."""
        return np.asarray(arr).reshape(-1)

    def _const(value: float) -> np.ndarray:
        """Replicate a rotation-invariant scalar across the n_rot axis."""
        return np.full(n_rot, float(value))

    # --- Equivariant and Z-summed moments (MATLAB lines 161-168) ---
    rows = [
        # w' * p1, w' * p2, w' * pz
        _row(w.T @ p1),
        _row(w.T @ p2),
        _row(w.T @ pz),
        # w' * p1^2, w' * p2^2, w' * pz^2
        _row(w.T @ (p1 ** 2)),
        _row(w.T @ (p2 ** 2)),
        _row(w.T @ (pz ** 2)),
        # w' * (p1.*p2), w' * (p1.*pz), w' * (p2.*pz)
        _row(w.T @ (p1 * p2)),
        _row(w.T @ (p1 * pz)),
        _row(w.T @ (p2 * pz)),
        # Zn' * p1, Zn' * p2
        _row(Zn.T @ p1),
        _row(Zn.T @ p2),
        # Zn' * p1^2, Zn' * p2^2
        _row(Zn.T @ (p1 ** 2)),
        _row(Zn.T @ (p2 ** 2)),
        # Zn' * (p1.*p2), Zn' * (p1.*pz)
        _row(Zn.T @ (p1 * p2)),
        _row(Zn.T @ (p1 * pz)),
        # w' * p1^3, w' * p2^3
        _row(w.T @ (p1 ** 3)),
        _row(w.T @ (p2 ** 3)),
        # Zn' * p1^3, Zn' * p2^3
        _row(Zn.T @ (p1 ** 3)),
        _row(Zn.T @ (p2 ** 3)),
    ]

    # --- Top-4 heaviest-atom rows (MATLAB lines 170-174) ---
    # MATLAB: [~, si] = sort(Z, 'descend') with 1-based indices. MATLAB's
    # sort is stable, so ties resolve in increasing original-index order;
    # a stable argsort of -Z reproduces that ordering with 0-based indices.
    si = np.argsort(-Z, kind="stable")
    for idx in si[: min(4, n_atoms)]:
        rows.append(_row(p1[idx, :]))
        rows.append(_row(p2[idx, :]))
        rows.append(_row(pz[idx, :]))

    # --- Top-3 atom-pair couplings (MATLAB lines 175-180) ---
    # MATLAB: for pp=1:min(3, n_atoms-1)
    #             ii = si(1); jj = si(min(pp+1, n_atoms));
    for pp in range(1, min(3, n_atoms - 1) + 1):
        ii = si[0]
        jj = si[min(pp + 1, n_atoms) - 1]  # MATLAB 1-based -> 0-based
        # Small epsilon avoids division by zero for coincident atoms.
        d_ij = float(np.linalg.norm(coords_c[ii] - coords_c[jj])) + 1e-8
        coup = Z[ii] * Z[jj] / d_ij
        rows.append(_row(coup * (p1[ii, :] - p1[jj, :])))
        rows.append(_row(coup * (p2[ii, :] - p2[jj, :])))

    # --- Invariant features (MATLAB lines 182-194) ---
    # Distance distribution moments: mean, std, min, max of pdist.
    if n_atoms >= 2:
        # Imported lazily so scipy is only required when distances exist.
        from scipy.spatial.distance import pdist

        D = pdist(coords_c)  # condensed upper-triangular pairwise distances
        # With ddof=1 and a single distance (two atoms) numpy divides by
        # zero and yields NaN, whereas MATLAB's std of a single observation
        # is 0; return 0 whenever there are not more samples than ddof.
        dist_std = np.std(D, ddof=std_ddof) if D.size > std_ddof else 0.0
        rows.extend(
            _const(v) for v in (np.mean(D), dist_std, np.min(D), np.max(D))
        )
    else:
        rows.extend(np.zeros(n_rot) for _ in range(4))

    # Top-4 radial distances from the centroid, largest first, zero-padded
    # when the molecule has fewer than four atoms (MATLAB lines 189-191).
    radii_desc = np.sort(np.sqrt(np.sum(coords_c ** 2, axis=1)))[::-1]
    n_rad = min(4, n_atoms)
    rows.extend(_const(radii_desc[k]) for k in range(n_rad))
    rows.extend(np.zeros(n_rot) for _ in range(n_rad, 4))

    # Final invariant rows (MATLAB lines 192-194).
    rows.append(_const(float(np.sum(Z ** 2)) / 100.0))
    rows.append(_const(np.mean(Z)))
    rows.append(_const(n_atoms))

    # --- Stack, then pad / truncate to n_feat_target rows (lines 196-199) ---
    F = np.stack(rows, axis=0)  # (~50, n_rot)
    n_rows = F.shape[0]
    if n_rows < n_feat_target:
        padding = np.zeros((n_feat_target - n_rows, n_rot))
        F = np.concatenate([F, padding], axis=0)
    elif n_rows > n_feat_target:
        F = F[:n_feat_target, :]
    return F


def matlab_angular_features(
    sample: "QM9Sample",
    n_rot: int = 12,
    n_feat_target: int = 14,
) -> np.ndarray:
    """Port of MATLAB `angular_features` using population std (ddof=0).

    The pairwise-distance std row is computed with ddof=0, which differs
    from MATLAB's default; use `matlab_angular_features_strict` when that
    row must match MATLAB. With the default `n_feat_target` the std row is
    truncated away, so both variants return the same matrix.

    Args:
        sample: a QM9Sample-like object exposing `coords` (n_atoms, 3) and
            `Z` (n_atoms,).
        n_rot: number of rotation samples (MATLAB calls this n_rotations).
        n_feat_target: number of feature rows (MATLAB calls this
            `n_feat_per_rot`; defaults to 14).

    Returns:
        np.ndarray of shape (n_feat_target, n_rot).
    """
    return _angular_feature_matrix(sample, n_rot, n_feat_target, std_ddof=0)


def matlab_angular_features_strict(
    sample: "QM9Sample",
    n_rot: int = 12,
    n_feat_target: int = 14,
) -> np.ndarray:
    """Strict variant matching MATLAB's std() default (sample std, ddof=1).

    Identical to `matlab_angular_features` except for the pairwise-distance
    std row. For a two-atom molecule (a single distance) that row is 0, as
    in MATLAB, rather than NaN.

    Args:
        sample: a QM9Sample-like object exposing `coords` (n_atoms, 3) and
            `Z` (n_atoms,).
        n_rot: number of rotation samples.
        n_feat_target: number of feature rows; defaults to 14.

    Returns:
        np.ndarray of shape (n_feat_target, n_rot).
    """
    return _angular_feature_matrix(sample, n_rot, n_feat_target, std_ddof=1)