# Source: tensor-group-sym/python/large_scale/data/test_matlab_equivalence.py
"""Numerical-equivalence test for the MATLAB angular_features port.

Writes a fixture of synthetic molecules (100 by default) to a .mat file,
runs an inline copy of the MATLAB QM9_experiment.angular_features body on
each molecule, then compares the result element-wise against both Python
ports (``matlab_angular_features`` and ``matlab_angular_features_strict``).

Pass criterion: max |MATLAB - Python| < 1e-10 per element.
"""

from __future__ import annotations

import argparse
import subprocess
import sys
from pathlib import Path

import numpy as np
from scipy.io import loadmat, savemat

# Local import; ensure parent dir on path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from data.matlab_angular_features import (
    matlab_angular_features,
    matlab_angular_features_strict,
)


class _Sample:
    """Minimal stand-in for QM9Sample."""
    def __init__(self, coords, Z):
        self.coords = coords
        self.Z = Z


def make_fixture(n_mol: int = 100, max_atoms: int = 12, seed: int = 1234):
    """Return ``n_mol`` reproducible synthetic molecules.

    Each molecule has between 2 and ``max_atoms`` atoms (inclusive). The
    coordinates are standard-normal draws scaled by 1.5. The atomic numbers
    are drawn uniformly from {1, 6, 7, 8, 9} and stored as float64.

    The per-molecule draw order (atom count, then coordinates, then atomic
    numbers) is fixed, so the same ``seed`` always yields the same fixture.
    """
    gen = np.random.default_rng(seed)
    elements = [1, 6, 7, 8, 9]

    def _one_molecule():
        count = int(gen.integers(2, max_atoms + 1))
        xyz = gen.normal(size=(count, 3)) * 1.5
        charges = gen.choice(elements, size=count).astype(np.float64)
        return _Sample(coords=xyz, Z=charges)

    return [_one_molecule() for _ in range(n_mol)]


def write_mat_fixture(samples, path: Path):
    """Save ``samples`` to ``path`` as MATLAB cell arrays.

    Three variables are written:

    - ``coords_cell``: a cell holding each molecule's coordinate matrix.
    - ``Z_cell``: a cell holding each molecule's atomic numbers as a column
      vector.
    - ``n_atoms_arr``: the int32 atom count of each molecule.
    """
    count = len(samples)
    # Pre-allocated 1-D object arrays, filled element by element, make scipy
    # emit MATLAB cell arrays. np.array(list_of_arrays, dtype=object) would
    # instead build a multi-dimensional array when the shapes coincide.
    coords_cell, z_cell = np.empty(count, dtype=object), np.empty(count, dtype=object)
    atom_counts = np.zeros(count, dtype=np.int32)
    for idx, sample in enumerate(samples):
        coords_cell[idx] = sample.coords
        z_cell[idx] = sample.Z.reshape(-1, 1)
        atom_counts[idx] = sample.coords.shape[0]
    payload = {
        "coords_cell": coords_cell,
        "Z_cell": z_cell,
        "n_atoms_arr": atom_counts,
    }
    savemat(str(path), payload)


# MATLAB source for ``run_matlab_features.m``. main() writes this text into
# --tmp_dir and invokes it via ``matlab -batch``. The file name must match the
# primary function name ``run_matlab_features`` for MATLAB to find it.
#
# run_matlab_features(in_path, out_path, n_rot, n_feat_target) loads the cell
# arrays written by write_mat_fixture. It then saves ``out``, allocated as
# zeros(n_feat_target, n_rot, n_mol). MATLAB drops trailing singleton
# dimensions, so for n_mol == 1 the saved array is 2-D.
#
# angular_features_inline is an inline copy of the angular_features body; see
# the comment inside the driver. This text is a reference copy, so keep it in
# sync with the real MATLAB method rather than editing it independently.
# Details that matter for the numerical comparison:
#   * ``std(D)`` uses MATLAB's default N-1 normalisation, which the
#     "strict (ddof=1)" Python port mirrors. MATLAB documents std of a scalar
#     as 0, which happens when n_atoms == 2 and pdist yields one distance.
#   * ``sort(Z, 'descend')`` is documented by MATLAB as stable, so atoms with
#     equal Z keep their input order. A Python port needs a stable sort to
#     pick the same atoms.
#   * ``pdist`` belongs to the Statistics and Machine Learning Toolbox.
#   * The top-atom rows (min(4, n_atoms) atoms) and the pair rows
#     (min(3, n_atoms - 1) pairs) are not zero-padded. Unlike the radius rows,
#     the row index of every later feature therefore depends on n_atoms.
#   * Rows are truncated or zero-padded to n_feat_target. With main()'s
#     default --n_feat 14, only the first 14 moment rows are compared, so the
#     std(D) row is not exercised.
# NOTE(review): the ``addpath`` on the driver's first line resolves relative
# to this file's location in --tmp_dir, not the repository. It likely points
# at a non-existent folder and appears unused because the feature body is
# inlined — confirm before relying on it.
MATLAB_DRIVER = r"""
function run_matlab_features(in_path, out_path, n_rot, n_feat_target)
    addpath(fullfile(fileparts(mfilename('fullpath')), '..', '..', '..', 'experiments'));
    s = load(in_path);
    n_mol = numel(s.coords_cell);
    out = zeros(n_feat_target, n_rot, n_mol);
    angles = (0:n_rot-1) * 2 * pi / n_rot;

    % Inline copy of the MATLAB angular_features body so we don't
    % depend on instantiating the QM9_experiment class (which loads
    % real QM9 data on construction).
    for mi = 1:n_mol
        coords = s.coords_cell{mi};
        Z = s.Z_cell{mi};
        out(:, :, mi) = angular_features_inline(coords, Z, angles, n_feat_target);
    end
    save(out_path, 'out', '-v7');
end

function F = angular_features_inline(coords, Z, angles, n_feat_target)
    n_rot = length(angles);
    n_atoms = size(coords, 1);
    Zn = Z(:);
    w = Zn / (sum(Zn) + 1e-10);
    coords_c = coords - mean(coords, 1);
    feat_list = {};

    e1 = [cos(angles); sin(angles); zeros(1, n_rot)];
    e2 = [-sin(angles); cos(angles); zeros(1, n_rot)];

    p1 = coords_c * e1;
    p2 = coords_c * e2;
    pz = coords_c * repmat([0; 0; 1], 1, n_rot);

    feat_list{end+1} = w' * p1;
    feat_list{end+1} = w' * p2;
    feat_list{end+1} = w' * pz;
    feat_list{end+1} = w' * (p1.^2);
    feat_list{end+1} = w' * (p2.^2);
    feat_list{end+1} = w' * (pz.^2);
    feat_list{end+1} = w' * (p1 .* p2);
    feat_list{end+1} = w' * (p1 .* pz);
    feat_list{end+1} = w' * (p2 .* pz);
    feat_list{end+1} = Zn' * p1;
    feat_list{end+1} = Zn' * p2;
    feat_list{end+1} = Zn' * (p1.^2);
    feat_list{end+1} = Zn' * (p2.^2);
    feat_list{end+1} = Zn' * (p1 .* p2);
    feat_list{end+1} = Zn' * (p1 .* pz);
    feat_list{end+1} = w' * (p1.^3);
    feat_list{end+1} = w' * (p2.^3);
    feat_list{end+1} = Zn' * (p1.^3);
    feat_list{end+1} = Zn' * (p2.^3);

    [~, si] = sort(Z, 'descend');
    for k = 1:min(4, n_atoms)
        idx = si(k);
        feat_list{end+1} = p1(idx, :);
        feat_list{end+1} = p2(idx, :);
        feat_list{end+1} = pz(idx, :);
    end
    for pp = 1:min(3, n_atoms - 1)
        ii = si(1);
        jj = si(min(pp + 1, n_atoms));
        d_ij = norm(coords_c(ii, :) - coords_c(jj, :)) + 1e-8;
        feat_list{end+1} = Z(ii) * Z(jj) / d_ij * (p1(ii, :) - p1(jj, :));
        feat_list{end+1} = Z(ii) * Z(jj) / d_ij * (p2(ii, :) - p2(jj, :));
    end

    if n_atoms >= 2
        D = pdist(coords_c);
        for v = [mean(D), std(D), min(D), max(D)]
            feat_list{end+1} = repmat(v, 1, n_rot);
        end
    else
        for k = 1:4
            feat_list{end+1} = zeros(1, n_rot);
        end
    end

    r = sqrt(sum(coords_c.^2, 2));
    rs = sort(r, 'descend');
    for k = 1:min(4, n_atoms)
        feat_list{end+1} = repmat(rs(k), 1, n_rot);
    end
    for k = n_atoms+1 : 4
        feat_list{end+1} = zeros(1, n_rot);
    end
    feat_list{end+1} = repmat(sum(Zn.^2) / 100, 1, n_rot);
    feat_list{end+1} = repmat(mean(Zn), 1, n_rot);
    feat_list{end+1} = repmat(n_atoms, 1, n_rot);

    F = cell2mat(feat_list');
    nr = size(F, 1);
    if nr < n_feat_target
        F = [F; zeros(n_feat_target - nr, n_rot)];
    elseif nr > n_feat_target
        F = F(1:n_feat_target, :);
    end
end
"""


def _matlab_quote(text: str) -> str:
    """Return *text* as a MATLAB single-quoted character-vector literal.

    MATLAB escapes an embedded ``'`` by doubling it. Without this, a temp
    path containing a quote would end the literal early and break the
    ``-batch`` command.
    """
    return "'" + text.replace("'", "''") + "'"


def _report_comparison(label: str, reference: np.ndarray, candidate: np.ndarray,
                       tol: float) -> bool:
    """Print an element-wise MATLAB-vs-Python report and return True on PASS.

    An element fails unless ``|reference - candidate| < tol``. Writing the
    test as ``~(diff < tol)`` also counts NaN differences as failures;
    ``diff > tol`` is False for NaN and would silently skip them.
    """
    diff = np.abs(reference - candidate)
    bad = ~(diff < tol)
    # np.argmax returns the first NaN when one is present, otherwise the
    # first occurrence of the largest difference.
    worst = tuple(int(i) for i in np.unravel_index(np.argmax(diff), diff.shape))
    print(f"[{label}] max |MATLAB - Python| = {diff[worst]:.3e} at {worst}")
    if not bad.any():
        print(f"[{label}] PASS (< {tol:g})")
        return True
    rows_bad = np.unique(np.argwhere(bad)[:, 0])
    print(f"[{label}] FAIL: {int(bad.sum())} elements above {tol:g}; "
          f"affected feature rows: {rows_bad.tolist()}")
    print(f"      worst: MATLAB={reference[worst]:.6e} "
          f"Python={candidate[worst]:.6e}")
    return False


def main():
    """Run the MATLAB-vs-Python angular-feature equivalence check.

    Steps:

    1. Build a synthetic fixture.
    2. Write the fixture and the MATLAB driver into ``--tmp_dir``.
    3. Run MATLAB in ``-batch`` mode.
    4. Run both Python ports on the same molecules.
    5. Print an element-wise comparison for each port.

    Raises:
        SystemExit: if an argument is out of range, MATLAB cannot be started
            or exits non-zero, MATLAB's output has an unexpected shape, or the
            strict (ddof=1) port is not within ``--tol`` of MATLAB. The loose
            (ddof=0) port is reported for information only and never fails
            the run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--n_mol", type=int, default=100)
    ap.add_argument("--n_rot", type=int, default=12)
    ap.add_argument("--n_feat", type=int, default=14)
    ap.add_argument("--tmp_dir", default="C:/Temp/starg_eqtest")
    ap.add_argument("--tol", type=float, default=1e-10,
                    help="per-element pass threshold on |MATLAB - Python|")
    args = ap.parse_args()
    for name in ("n_mol", "n_rot", "n_feat"):
        if getattr(args, name) < 1:
            ap.error(f"--{name} must be >= 1")

    tmp = Path(args.tmp_dir)
    tmp.mkdir(parents=True, exist_ok=True)
    fixture_path = tmp / "fixture.mat"
    matlab_out_path = tmp / "matlab_out.mat"
    matlab_driver_path = tmp / "run_matlab_features.m"

    print(f"[setup] generating {args.n_mol} synthetic molecules ...")
    samples = make_fixture(n_mol=args.n_mol)
    write_mat_fixture(samples, fixture_path)
    matlab_driver_path.write_text(MATLAB_DRIVER, encoding="utf-8")
    print(f"[setup] fixture: {fixture_path}")
    print(f"[setup] driver:  {matlab_driver_path}")

    # Remove any output from a previous run, so a stale file can never be
    # compared in place of a fresh one.
    if matlab_out_path.exists():
        matlab_out_path.unlink()

    matlab_expr = (
        f"addpath({_matlab_quote(tmp.as_posix())}); "
        f"run_matlab_features({_matlab_quote(fixture_path.as_posix())}, "
        f"{_matlab_quote(matlab_out_path.as_posix())}, {args.n_rot}, {args.n_feat})"
    )
    cmd = ["matlab", "-batch", matlab_expr]
    print(f"[matlab] running: {' '.join(cmd)}")
    try:
        # errors="replace": MATLAB's console output may not decode cleanly in
        # the locale encoding, and a decode error would hide the real result.
        res = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    except FileNotFoundError as err:
        raise SystemExit("MATLAB executable 'matlab' not found on PATH") from err
    if res.returncode != 0:
        print("[matlab] STDOUT:")
        print(res.stdout[-2000:])
        print("[matlab] STDERR:")
        print(res.stderr[-2000:])
        raise SystemExit(f"MATLAB exited with code {res.returncode}")
    print("[matlab] ok")

    matlab_feats = loadmat(str(matlab_out_path))["out"]
    print(f"[matlab] output shape: {matlab_feats.shape}")
    expected_shape = (args.n_feat, args.n_rot, args.n_mol)
    # MATLAB drops trailing singleton dimensions when saving (e.g. n_mol == 1
    # yields a 2-D array). Pad them back before validating the shape.
    padded_shape = matlab_feats.shape + (1,) * (len(expected_shape) - matlab_feats.ndim)
    if padded_shape != expected_shape:
        raise SystemExit(
            f"unexpected MATLAB output shape {matlab_feats.shape}; "
            f"expected {expected_shape}"
        )
    matlab_feats = matlab_feats.reshape(expected_shape)

    # Run Python ports (loose default-numpy and strict ddof=1) for both
    print("[python] running port ...")
    py_loose = np.zeros(expected_shape, dtype=np.float64)
    py_strict = np.zeros(expected_shape, dtype=np.float64)
    for i, s in enumerate(samples):
        py_loose[:, :, i] = matlab_angular_features(
            s, n_rot=args.n_rot, n_feat_target=args.n_feat
        )
        py_strict[:, :, i] = matlab_angular_features_strict(
            s, n_rot=args.n_rot, n_feat_target=args.n_feat
        )

    print()
    # MATLAB's std() normalises by N-1, so only the strict port is expected
    # to match exactly. The loose port is shown for information only.
    _report_comparison("loose (ddof=0)", matlab_feats, py_loose, args.tol)
    strict_ok = _report_comparison("strict (ddof=1)", matlab_feats, py_strict, args.tol)
    print()
    if not strict_ok:
        # A non-zero exit lets CI detect a mismatch instead of just printing it.
        raise SystemExit(
            f"strict (ddof=1) port does not match MATLAB within {args.tol:g}"
        )


if __name__ == "__main__":
    # main() returns None on success (exit status 0) and raises SystemExit
    # itself on failure, so wrapping it is equivalent to calling it directly.
    raise SystemExit(main())