"""Numerical-equivalence test for the MATLAB angular_features port.
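Writes a fixture of synthetic molecules (100 by default) to a .mat file,
runs an inline MATLAB copy of the QM9_experiment.angular_features body on
each (the class itself is not instantiated), then compares the result
against the loose and strict Python ports element-wise.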
Pass criterion: max |MATLAB - Python| < 1e-10 per element.
"""
from __future__ import annotations
import argparse
import subprocess
import sys
from pathlib import Path
import numpy as np
from scipy.io import loadmat, savemat
# Local import; ensure parent dir on path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from data.matlab_angular_features import (
matlab_angular_features,
matlab_angular_features_strict,
)
class _Sample:
    """Lightweight substitute for QM9Sample exposing ``coords`` and ``Z``."""

    def __init__(self, coords, Z):
        # Both arguments are stored as given; nothing is copied or validated.
        self.coords, self.Z = coords, Z
def make_fixture(n_mol: int = 100, max_atoms: int = 12, seed: int = 1234):
    """Return a list of ``n_mol`` reproducible synthetic molecules.

    Each molecule is a ``_Sample`` with between 2 and ``max_atoms`` atoms
    (inclusive), normally distributed coordinates scaled by 1.5, and
    float64 nuclear charges drawn from {1, 6, 7, 8, 9}. The generator is
    seeded with ``seed``.
    """
    generator = np.random.default_rng(seed)
    element_pool = [1, 6, 7, 8, 9]

    def _draw_molecule():
        # The draw order (atom count, coordinates, charges) determines the
        # random stream and therefore the fixture contents.
        size = int(generator.integers(2, max_atoms + 1))
        xyz = 1.5 * generator.normal(size=(size, 3))
        charges = generator.choice(element_pool, size=size).astype(np.float64)
        return _Sample(coords=xyz, Z=charges)

    return [_draw_molecule() for _ in range(n_mol)]
def write_mat_fixture(samples, path: Path):
    """Save the fixture as a .mat file that MATLAB can load.

    Three variables are written: ``coords_cell`` and ``Z_cell`` (object
    arrays with one entry per sample, each ``Z`` reshaped to a column) and
    ``n_atoms_arr`` (int32 row count of every ``coords`` array).
    """
    count = len(samples)
    coords_cell = np.empty(count, dtype=object)
    Z_cell = np.empty(count, dtype=object)
    # Assign entry by entry so each array is stored as one object element.
    for slot, sample in enumerate(samples):
        coords_cell[slot] = sample.coords
        Z_cell[slot] = sample.Z.reshape(-1, 1)
    n_atoms_arr = np.array(
        [sample.coords.shape[0] for sample in samples], dtype=np.int32
    )
    payload = {
        "coords_cell": coords_cell,
        "Z_cell": Z_cell,
        "n_atoms_arr": n_atoms_arr,
    }
    savemat(str(path), payload)
# MATLAB source that main() writes to run_matlab_features.m and invokes via
# `matlab -batch`. It is a raw string, so Python leaves backslashes untouched.
#
#   run_matlab_features(in_path, out_path, n_rot, n_feat_target)
#       Loads coords_cell / Z_cell from in_path, evaluates the features at
#       n_rot evenly spaced angles in [0, 2*pi) for every molecule, and saves
#       the (n_feat_target, n_rot, n_mol) array `out` to out_path as -v7.
#   angular_features_inline(coords, Z, angles, n_feat_target)
#       Stacks one row per feature (n_rot columns), then zero-pads or
#       truncates the stack to exactly n_feat_target rows.
#
# NOTE(review): pdist belongs to the Statistics and Machine Learning Toolbox;
# confirm it is installed wherever this test runs.
# NOTE(review): the addpath(...) target is resolved relative to the directory
# the driver file is written to, and the inline code does not appear to call
# anything from it -- confirm whether it is still needed.
MATLAB_DRIVER = r"""
function run_matlab_features(in_path, out_path, n_rot, n_feat_target)
addpath(fullfile(fileparts(mfilename('fullpath')), '..', '..', '..', 'experiments'));
s = load(in_path);
n_mol = numel(s.coords_cell);
out = zeros(n_feat_target, n_rot, n_mol);
angles = (0:n_rot-1) * 2 * pi / n_rot;
% Inline copy of the MATLAB angular_features body so we don't
% depend on instantiating the QM9_experiment class (which loads
% real QM9 data on construction).
for mi = 1:n_mol
coords = s.coords_cell{mi};
Z = s.Z_cell{mi};
out(:, :, mi) = angular_features_inline(coords, Z, angles, n_feat_target);
end
save(out_path, 'out', '-v7');
end
function F = angular_features_inline(coords, Z, angles, n_feat_target)
n_rot = length(angles);
n_atoms = size(coords, 1);
Zn = Z(:);
w = Zn / (sum(Zn) + 1e-10);
coords_c = coords - mean(coords, 1);
feat_list = {};
e1 = [cos(angles); sin(angles); zeros(1, n_rot)];
e2 = [-sin(angles); cos(angles); zeros(1, n_rot)];
p1 = coords_c * e1;
p2 = coords_c * e2;
pz = coords_c * repmat([0; 0; 1], 1, n_rot);
feat_list{end+1} = w' * p1;
feat_list{end+1} = w' * p2;
feat_list{end+1} = w' * pz;
feat_list{end+1} = w' * (p1.^2);
feat_list{end+1} = w' * (p2.^2);
feat_list{end+1} = w' * (pz.^2);
feat_list{end+1} = w' * (p1 .* p2);
feat_list{end+1} = w' * (p1 .* pz);
feat_list{end+1} = w' * (p2 .* pz);
feat_list{end+1} = Zn' * p1;
feat_list{end+1} = Zn' * p2;
feat_list{end+1} = Zn' * (p1.^2);
feat_list{end+1} = Zn' * (p2.^2);
feat_list{end+1} = Zn' * (p1 .* p2);
feat_list{end+1} = Zn' * (p1 .* pz);
feat_list{end+1} = w' * (p1.^3);
feat_list{end+1} = w' * (p2.^3);
feat_list{end+1} = Zn' * (p1.^3);
feat_list{end+1} = Zn' * (p2.^3);
[~, si] = sort(Z, 'descend');
for k = 1:min(4, n_atoms)
idx = si(k);
feat_list{end+1} = p1(idx, :);
feat_list{end+1} = p2(idx, :);
feat_list{end+1} = pz(idx, :);
end
for pp = 1:min(3, n_atoms - 1)
ii = si(1);
jj = si(min(pp + 1, n_atoms));
d_ij = norm(coords_c(ii, :) - coords_c(jj, :)) + 1e-8;
feat_list{end+1} = Z(ii) * Z(jj) / d_ij * (p1(ii, :) - p1(jj, :));
feat_list{end+1} = Z(ii) * Z(jj) / d_ij * (p2(ii, :) - p2(jj, :));
end
if n_atoms >= 2
D = pdist(coords_c);
for v = [mean(D), std(D), min(D), max(D)]
feat_list{end+1} = repmat(v, 1, n_rot);
end
else
for k = 1:4
feat_list{end+1} = zeros(1, n_rot);
end
end
r = sqrt(sum(coords_c.^2, 2));
rs = sort(r, 'descend');
for k = 1:min(4, n_atoms)
feat_list{end+1} = repmat(rs(k), 1, n_rot);
end
for k = n_atoms+1 : 4
feat_list{end+1} = zeros(1, n_rot);
end
feat_list{end+1} = repmat(sum(Zn.^2) / 100, 1, n_rot);
feat_list{end+1} = repmat(mean(Zn), 1, n_rot);
feat_list{end+1} = repmat(n_atoms, 1, n_rot);
F = cell2mat(feat_list');
nr = size(F, 1);
if nr < n_feat_target
F = [F; zeros(n_feat_target - nr, n_rot)];
elseif nr > n_feat_target
F = F(1:n_feat_target, :);
end
end
"""
def _matlab_quote(path: Path) -> str:
    """Return *path* (forward slashes) as a single-quoted MATLAB char literal."""
    # Inside '...' MATLAB represents a literal single quote by doubling it.
    return "'" + path.as_posix().replace("'", "''") + "'"


def main():
    """Run the MATLAB-vs-Python equivalence check from the command line.

    Steps: build the synthetic fixture, write it and the MATLAB driver into
    ``--tmp_dir``, run the driver through ``matlab -batch``, load its output,
    evaluate both Python ports on the same samples, and print the maximum
    absolute difference with a PASS/FAIL verdict for each port.

    The verdicts are only printed; the exit status does not reflect them.

    Raises:
        SystemExit: if an integer option is not positive, the ``matlab``
            executable cannot be started, MATLAB exits with a non-zero
            code, or its output has an unexpected shape.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--n_mol", type=int, default=100)
    ap.add_argument("--n_rot", type=int, default=12)
    ap.add_argument("--n_feat", type=int, default=14)
    ap.add_argument("--tmp_dir", default="C:/Temp/starg_eqtest")
    args = ap.parse_args()
    # Reject sizes that would only fail later with an unrelated error.
    for flag in ("n_mol", "n_rot", "n_feat"):
        if getattr(args, flag) < 1:
            ap.error(f"--{flag} must be a positive integer")

    tol = 1e-10  # pass criterion: max |MATLAB - Python| per element

    tmp = Path(args.tmp_dir)
    tmp.mkdir(parents=True, exist_ok=True)
    fixture_path = tmp / "fixture.mat"
    matlab_out_path = tmp / "matlab_out.mat"
    matlab_driver_path = tmp / "run_matlab_features.m"

    print(f"[setup] generating {args.n_mol} synthetic molecules ...")
    samples = make_fixture(n_mol=args.n_mol)
    write_mat_fixture(samples, fixture_path)
    matlab_driver_path.write_text(MATLAB_DRIVER, encoding="utf-8")
    print(f"[setup] fixture: {fixture_path}")
    print(f"[setup] driver: {matlab_driver_path}")

    cmd = [
        "matlab", "-batch",
        f"addpath({_matlab_quote(tmp)}); "
        f"run_matlab_features({_matlab_quote(fixture_path)}, "
        f"{_matlab_quote(matlab_out_path)}, {args.n_rot}, {args.n_feat})",
    ]
    print(f"[matlab] running: {' '.join(cmd)}")
    try:
        res = subprocess.run(cmd, capture_output=True, text=True)
    except FileNotFoundError as err:
        # Give a readable message instead of a traceback when MATLAB is absent.
        raise SystemExit("MATLAB executable 'matlab' not found on PATH") from err
    if res.returncode != 0:
        print("[matlab] STDOUT:")
        print(res.stdout[-2000:])
        print("[matlab] STDERR:")
        print(res.stderr[-2000:])
        raise SystemExit(f"MATLAB exited with code {res.returncode}")
    print("[matlab] ok")

    matlab_data = loadmat(str(matlab_out_path))
    matlab_feats = np.asarray(matlab_data["out"], dtype=np.float64)
    # MATLAB drops trailing singleton dimensions, so e.g. n_mol == 1 comes
    # back 2-D; restore the (n_feat, n_rot, n_mol) layout indexed below.
    expected_shape = (args.n_feat, args.n_rot, len(samples))
    while matlab_feats.ndim < 3:
        matlab_feats = matlab_feats[..., np.newaxis]
    if matlab_feats.shape != expected_shape:
        raise SystemExit(
            f"unexpected MATLAB output shape {matlab_feats.shape}; "
            f"expected {expected_shape}"
        )
    print(f"[matlab] output shape: {matlab_feats.shape}")

    # Evaluate both Python ports: loose (NumPy default ddof=0) and strict
    # (ddof=1).
    print("[python] running port ...")
    py_loose = np.zeros_like(matlab_feats, dtype=np.float64)
    py_strict = np.zeros_like(matlab_feats, dtype=np.float64)
    for i, s in enumerate(samples):
        py_loose[:, :, i] = matlab_angular_features(
            s, n_rot=args.n_rot, n_feat_target=args.n_feat
        )
        py_strict[:, :, i] = matlab_angular_features_strict(
            s, n_rot=args.n_rot, n_feat_target=args.n_feat
        )
    print()

    for label, py_arr in [("loose (ddof=0)", py_loose), ("strict (ddof=1)", py_strict)]:
        diff = np.abs(matlab_feats - py_arr)
        max_diff = diff.max()
        max_idx = np.unravel_index(np.argmax(diff), diff.shape)
        print(f"[{label}] max |MATLAB - Python| = {max_diff:.3e} at {max_idx}")
        if max_diff < tol:
            print(f"[{label}] PASS (< {tol:g})")
        else:
            bad = diff > tol
            n_bad = int(bad.sum())
            # Axis 0 is the feature row; list every row with a violation.
            row_bad = np.unique(np.argwhere(bad)[:, 0])
            print(f"[{label}] FAIL: {n_bad} elements above {tol:g}; "
                  f"affected feature rows: {row_bad.tolist()}")
            mat_val = matlab_feats[max_idx]
            py_val = py_arr[max_idx]
            print(f" worst: MATLAB={mat_val:.6e} Python={py_val:.6e}")
        print()
# Run the equivalence check only when this file is executed as a script.
if __name__ == "__main__":
    main()