"""Numerical-equivalence test for the MATLAB angular_features port. Writes a fixture of 100 synthetic molecules to a .mat file, calls the MATLAB QM9_experiment.angular_features on each, then compares against the Python port element-wise. Pass criterion: max |MATLAB - Python| < 1e-10 per element. """ from __future__ import annotations import argparse import subprocess import sys from pathlib import Path import numpy as np from scipy.io import loadmat, savemat # Local import; ensure parent dir on path sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from data.matlab_angular_features import ( matlab_angular_features, matlab_angular_features_strict, ) class _Sample: """Minimal stand-in for QM9Sample.""" def __init__(self, coords, Z): self.coords = coords self.Z = Z def make_fixture(n_mol: int = 100, max_atoms: int = 12, seed: int = 1234): """Build n_mol synthetic molecules with random coords + Z in {1,6,7,8,9}.""" rng = np.random.default_rng(seed) samples = [] for _ in range(n_mol): n_atoms = int(rng.integers(2, max_atoms + 1)) coords = rng.normal(size=(n_atoms, 3)) * 1.5 Z = rng.choice([1, 6, 7, 8, 9], size=n_atoms).astype(np.float64) samples.append(_Sample(coords=coords, Z=Z)) return samples def write_mat_fixture(samples, path: Path): """Write the fixture to a .mat file MATLAB can load.""" coords_cell = np.empty(len(samples), dtype=object) Z_cell = np.empty(len(samples), dtype=object) n_atoms_arr = np.zeros(len(samples), dtype=np.int32) for i, s in enumerate(samples): coords_cell[i] = s.coords Z_cell[i] = s.Z.reshape(-1, 1) n_atoms_arr[i] = s.coords.shape[0] savemat(str(path), { "coords_cell": coords_cell, "Z_cell": Z_cell, "n_atoms_arr": n_atoms_arr, }) MATLAB_DRIVER = r""" function run_matlab_features(in_path, out_path, n_rot, n_feat_target) addpath(fullfile(fileparts(mfilename('fullpath')), '..', '..', '..', 'experiments')); s = load(in_path); n_mol = numel(s.coords_cell); out = zeros(n_feat_target, n_rot, n_mol); angles = (0:n_rot-1) * 2 * pi / n_rot; % Inline copy of the MATLAB angular_features body so we don't % depend on instantiating the QM9_experiment class (which loads % real QM9 data on construction). for mi = 1:n_mol coords = s.coords_cell{mi}; Z = s.Z_cell{mi}; out(:, :, mi) = angular_features_inline(coords, Z, angles, n_feat_target); end save(out_path, 'out', '-v7'); end function F = angular_features_inline(coords, Z, angles, n_feat_target) n_rot = length(angles); n_atoms = size(coords, 1); Zn = Z(:); w = Zn / (sum(Zn) + 1e-10); coords_c = coords - mean(coords, 1); feat_list = {}; e1 = [cos(angles); sin(angles); zeros(1, n_rot)]; e2 = [-sin(angles); cos(angles); zeros(1, n_rot)]; p1 = coords_c * e1; p2 = coords_c * e2; pz = coords_c * repmat([0; 0; 1], 1, n_rot); feat_list{end+1} = w' * p1; feat_list{end+1} = w' * p2; feat_list{end+1} = w' * pz; feat_list{end+1} = w' * (p1.^2); feat_list{end+1} = w' * (p2.^2); feat_list{end+1} = w' * (pz.^2); feat_list{end+1} = w' * (p1 .* p2); feat_list{end+1} = w' * (p1 .* pz); feat_list{end+1} = w' * (p2 .* pz); feat_list{end+1} = Zn' * p1; feat_list{end+1} = Zn' * p2; feat_list{end+1} = Zn' * (p1.^2); feat_list{end+1} = Zn' * (p2.^2); feat_list{end+1} = Zn' * (p1 .* p2); feat_list{end+1} = Zn' * (p1 .* pz); feat_list{end+1} = w' * (p1.^3); feat_list{end+1} = w' * (p2.^3); feat_list{end+1} = Zn' * (p1.^3); feat_list{end+1} = Zn' * (p2.^3); [~, si] = sort(Z, 'descend'); for k = 1:min(4, n_atoms) idx = si(k); feat_list{end+1} = p1(idx, :); feat_list{end+1} = p2(idx, :); feat_list{end+1} = pz(idx, :); end for pp = 1:min(3, n_atoms - 1) ii = si(1); jj = si(min(pp + 1, n_atoms)); d_ij = norm(coords_c(ii, :) - coords_c(jj, :)) + 1e-8; feat_list{end+1} = Z(ii) * Z(jj) / d_ij * (p1(ii, :) - p1(jj, :)); feat_list{end+1} = Z(ii) * Z(jj) / d_ij * (p2(ii, :) - p2(jj, :)); end if n_atoms >= 2 D = pdist(coords_c); for v = [mean(D), std(D), min(D), max(D)] feat_list{end+1} = repmat(v, 1, n_rot); end else for k = 1:4 feat_list{end+1} = zeros(1, n_rot); end end r = sqrt(sum(coords_c.^2, 2)); rs = sort(r, 'descend'); for k = 1:min(4, n_atoms) feat_list{end+1} = repmat(rs(k), 1, n_rot); end for k = n_atoms+1 : 4 feat_list{end+1} = zeros(1, n_rot); end feat_list{end+1} = repmat(sum(Zn.^2) / 100, 1, n_rot); feat_list{end+1} = repmat(mean(Zn), 1, n_rot); feat_list{end+1} = repmat(n_atoms, 1, n_rot); F = cell2mat(feat_list'); nr = size(F, 1); if nr < n_feat_target F = [F; zeros(n_feat_target - nr, n_rot)]; elseif nr > n_feat_target F = F(1:n_feat_target, :); end end """ def main(): ap = argparse.ArgumentParser() ap.add_argument("--n_mol", type=int, default=100) ap.add_argument("--n_rot", type=int, default=12) ap.add_argument("--n_feat", type=int, default=14) ap.add_argument("--tmp_dir", default="C:/Temp/starg_eqtest") args = ap.parse_args() tmp = Path(args.tmp_dir) tmp.mkdir(parents=True, exist_ok=True) fixture_path = tmp / "fixture.mat" matlab_out_path = tmp / "matlab_out.mat" matlab_driver_path = tmp / "run_matlab_features.m" print(f"[setup] generating {args.n_mol} synthetic molecules ...") samples = make_fixture(n_mol=args.n_mol) write_mat_fixture(samples, fixture_path) matlab_driver_path.write_text(MATLAB_DRIVER, encoding="utf-8") print(f"[setup] fixture: {fixture_path}") print(f"[setup] driver: {matlab_driver_path}") cmd = [ "matlab", "-batch", f"addpath('{tmp.as_posix()}'); " f"run_matlab_features('{fixture_path.as_posix()}', " f"'{matlab_out_path.as_posix()}', {args.n_rot}, {args.n_feat})", ] print(f"[matlab] running: {' '.join(cmd)}") res = subprocess.run(cmd, capture_output=True, text=True) if res.returncode != 0: print("[matlab] STDOUT:") print(res.stdout[-2000:]) print("[matlab] STDERR:") print(res.stderr[-2000:]) raise SystemExit(f"MATLAB exited with code {res.returncode}") print("[matlab] ok") matlab_data = loadmat(str(matlab_out_path)) matlab_feats = matlab_data["out"] # (n_feat, n_rot, n_mol) print(f"[matlab] output shape: {matlab_feats.shape}") # Run Python ports (loose default-numpy and strict ddof=1) for both print(f"[python] running port ...") py_loose = np.zeros_like(matlab_feats, dtype=np.float64) py_strict = np.zeros_like(matlab_feats, dtype=np.float64) for i, s in enumerate(samples): py_loose[:, :, i] = matlab_angular_features( s, n_rot=args.n_rot, n_feat_target=args.n_feat ) py_strict[:, :, i] = matlab_angular_features_strict( s, n_rot=args.n_rot, n_feat_target=args.n_feat ) print() for label, py_arr in [("loose (ddof=0)", py_loose), ("strict (ddof=1)", py_strict)]: diff = np.abs(matlab_feats - py_arr) max_diff = diff.max() max_idx = np.unravel_index(np.argmax(diff), diff.shape) print(f"[{label}] max |MATLAB - Python| = {max_diff:.3e} at {max_idx}") if max_diff < 1e-10: print(f"[{label}] PASS (< 1e-10)") else: n_bad = int((diff > 1e-10).sum()) row_bad = np.unique(np.argwhere(diff > 1e-10)[:, 0]) print(f"[{label}] FAIL: {n_bad} elements above 1e-10; " f"affected feature rows: {row_bad.tolist()}") mat_val = matlab_feats[max_idx] py_val = py_arr[max_idx] print(f" worst: MATLAB={mat_val:.6e} Python={py_val:.6e}") print() if __name__ == "__main__": main()