# da-message-passing / src / experiments / mc_variances.py
"""Compute the marginal variances from the sparse posterior precision matrix using mc.

Based off this paper: https://arxiv.org/abs/1705.08656
"""
import multiprocessing
from math import ceil

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.sparse
from numpy import ndarray
from numpy.random import Generator, default_rng
from scipy.sparse import spmatrix
from tqdm import tqdm

from damp import gp, ground_truth_cache, inla_bridge
from damp.gp import Shape


def main() -> None:
    """Plot INLA and Monte Carlo posterior standard deviations side by side.

    Builds a GP prior on a ``Shape(128, 128)`` grid, draws 1500 noisy
    observations of a cached ground truth, and forms the posterior. The
    marginal standard deviations are then estimated twice: by Monte Carlo
    (``_mc_scipy`` with 200 samples) and by ``inla_bridge.run``. Both images
    share the colour range of the INLA result and are written to
    ``plots/mc_demo.png``.
    """
    rng = default_rng(seed=1120)
    noise = 1e-3
    sample_count = 200

    grid_prior = gp.get_prior(Shape(128, 128))
    truth = next(ground_truth_cache.load_or_gen(grid_prior))
    observations = gp.choose_observations(
        rng, n_obs=1500, ground_truth=truth, obs_noise=noise
    )
    grid_posterior = gp.get_posterior(grid_prior, observations, noise)

    # The MC estimate is computed first so that the random stream is consumed
    # in a fixed order (observations, then samples).
    mc_stds = _mc_scipy(rng, grid_prior, grid_posterior, sample_count)
    _, inla_stds = inla_bridge.run(grid_prior, observations)

    # Both panels use the INLA range so the colours are directly comparable.
    colour_range = {"vmin": inla_stds.min(), "vmax": inla_stds.max()}
    panels = (
        ("inla", inla_stds.T),
        ("mc", mc_stds.reshape(grid_prior.interior_shape).T),
    )

    _fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    for axis, (title, image) in zip(axes, panels):
        axis.imshow(image, **colour_range)
        axis.set_title(title)
    plt.savefig("plots/mc_demo.png")
    plt.close()


def _mc_scipy(
    rng: Generator, prior: gp.Prior, posterior: gp.Posterior, n_samples: int
) -> ndarray:
    """Estimate the marginal posterior standard deviations by Monte Carlo.

    Draws ``n_samples`` samples with ``_sample_x`` and splits the ``d`` indices
    of the precision matrix into contiguous batches, one per CPU. Each batch is
    handed to ``_compute_marginal_variance_scipy`` in a worker process.

    Args:
        rng: Random generator used to draw the samples.
        prior: GP prior; forwarded to ``_sample_x``.
        posterior: GP posterior whose ``precision`` has shape ``(d, d)``.
        n_samples: Number of Monte Carlo samples.

    Returns:
        Array of shape ``(d,)`` holding the square roots of the estimated
        marginal variances.
    """
    d = posterior.precision.shape[0]
    x = _sample_x(rng, prior, posterior, n_samples)

    n_processes = multiprocessing.cpu_count()
    batch_size = int(ceil(d / n_processes))
    batches = [
        list(range(start, min(start + batch_size, d)))
        for start in range(0, d, batch_size)
    ]

    print(f"Split {d} into {len(batches)} batches of size {batch_size}")
    sigma2 = np.zeros(shape=(d,))

    # The context manager guarantees the worker processes are cleaned up even
    # if a worker raises. Leaving the block terminates the pool, so every
    # result is collected before the block ends.
    with multiprocessing.Pool(processes=n_processes) as pool:
        pending = [
            pool.apply_async(
                _compute_marginal_variance_scipy, (batch, posterior, x, n_samples)
            )
            for batch in batches
        ]
        for result in pending:
            for i, r in result.get():
                sigma2[i] = r

    return np.sqrt(sigma2)


def _sample_x(
    rng: Generator, prior: gp.Prior, posterior: gp.Posterior, n: int
) -> ndarray:
    """Draw ``n`` zero-mean samples by solving against the posterior precision.

    With ``z1, z2`` standard normal, ``y = G.T @ z1 + H.T @ z2`` has covariance
    ``G.T @ G + H.T @ H``. Solving ``Q x = y`` then gives
    ``Cov(x) = Q^-1 (G.T G + H.T H) Q^-1``, which equals ``Q^-1`` when
    ``Q = G.T G + H.T H``.

    NOTE(review): this presumes ``prior.precision_decomposed`` is a square
    factor ``G`` of the prior precision and that ``obs_noise`` is scaled such
    that ``H.T @ H`` is the likelihood precision -- confirm against ``damp.gp``.

    Args:
        rng: Random generator used for the standard-normal draws.
        prior: Provides ``precision_decomposed`` (``G``).
        posterior: Provides ``precision`` (``Q``), ``obs_noise`` and
            ``obs_location_mask``.
        n: Number of samples to draw.

    Returns:
        Array of shape ``(d, n)`` with one sample per column, where ``d`` is
        the number of rows of ``Q``.
    """
    # Imported explicitly: `import scipy.sparse` alone does not guarantee that
    # the `linalg` submodule has been loaded.
    from scipy.sparse.linalg import spsolve

    Q = posterior.precision
    G = prior.precision_decomposed
    H = 1 / posterior.obs_noise * scipy.sparse.diags(posterior.obs_location_mask)

    d = Q.shape[0]
    z1 = rng.normal(size=(d, n))
    z2 = rng.normal(size=(d, n))

    y = G.T @ z1 + H.T @ z2

    # Reparametrization: x = Q^-1 y is a sample from N(0, Q^-1).
    x = spsolve(Q, y)

    # spsolve treats a single-column right-hand side as a vector and returns a
    # 1-D result; restore the (d, n) layout so callers can always index x[i, :].
    return np.asarray(x).reshape(d, n)


def _compute_marginal_variance_scipy(
    batch: list[int], posterior: gp.Posterior, x: ndarray, N: int
) -> list[tuple[int, float]]:
    """Estimate the marginal variance of each index in ``batch`` from samples.

    For every index ``i`` the estimate is::

        1 / Q[i, i] + sum_k (Q[i, -i] @ x[-i, k]) ** 2 / (N * Q[i, i] ** 2)

    where ``-i`` means "all indices except ``i``". For a zero-mean Gaussian
    with precision ``Q`` the first term is the conditional variance of
    ``x_i`` given the rest, and the second is a Monte Carlo estimate of the
    variance of its conditional mean ``-Q[i, -i] @ x[-i] / Q[i, i]``.

    Args:
        batch: Indices whose marginal variance is computed.
        posterior: Provides the sparse precision matrix ``precision`` (``Q``).
        x: Dense array of shape ``(d, N)`` with one sample per column.
        N: Number of samples used to normalise the Monte Carlo term.

    Returns:
        A list of ``(index, variance)`` pairs, in the order of ``batch``.
    """
    Q = posterior.precision
    diagonal = np.asarray(Q.diagonal()).ravel()

    # Zero the diagonal once, so that row i of `off_diagonal` is Q(i, -i) with
    # a zero at column i. This avoids copying almost all of `x` (and
    # re-stacking the row) for every single index.
    off_diagonal = (Q - scipy.sparse.diags(diagonal)).tocsr()
    samples = np.asarray(x)

    results = []
    for i in tqdm(batch):
        q_ii = diagonal[i]
        # Q(i, -i) @ x(-i, :): one value per sample.
        cross = np.asarray(off_diagonal[i, :] @ samples).ravel()
        variance = 1 / q_ii + np.sum(cross**2) / (N * q_ii**2)
        results.append((i, float(variance)))
    return results


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()