"""Compute the marginal variances from the sparse posterior precision matrix using mc. Based off this paper: https://arxiv.org/abs/1705.08656 """ import multiprocessing from math import ceil import matplotlib.pyplot as plt import numpy as np import scipy import scipy.sparse from numpy import ndarray from numpy.random import Generator, default_rng from scipy.sparse import spmatrix from tqdm import tqdm from damp import gp, ground_truth_cache, inla_bridge from damp.gp import Shape def main() -> None: numpy_rng = default_rng(seed=1120) prior = gp.get_prior(Shape(128, 128)) ground_truth = next(ground_truth_cache.load_or_gen(prior)) obs_noise = 1e-3 obs = gp.choose_observations( numpy_rng, n_obs=1500, ground_truth=ground_truth, obs_noise=obs_noise ) posterior = gp.get_posterior(prior, obs, obs_noise) n_samples = 200 mc_stds = _mc_scipy(numpy_rng, prior, posterior, n_samples) inla_means, inla_stds = inla_bridge.run(prior, obs) vmin = inla_stds.min() vmax = inla_stds.max() fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5)) axes[0].imshow(inla_stds.T, vmin=vmin, vmax=vmax) axes[1].imshow(mc_stds.reshape(prior.interior_shape).T, vmin=vmin, vmax=vmax) axes[0].set_title("inla") axes[1].set_title("mc") plt.savefig("plots/mc_demo.png") plt.close() def _mc_scipy( rng: Generator, prior: gp.Prior, posterior: gp.Posterior, n_samples: int ) -> ndarray: d = posterior.precision.shape[0] x = _sample_x(rng, prior, posterior, n_samples) n_processes = multiprocessing.cpu_count() pool = multiprocessing.Pool(processes=n_processes) batch_size = int(ceil(d / n_processes)) batches = [ list(range(start, min(start + batch_size, d))) for start in range(0, d, batch_size) ] print(f"Split {d} into {len(batches)} batches of size {batch_size}") results = [] for batch in batches: results.append( pool.apply_async( _compute_marginal_variance_scipy, (batch, posterior, x, n_samples) ) ) sigma2 = np.zeros(shape=(d,)) for result in results: for i, r in result.get(): sigma2[i] = r return np.sqrt(sigma2) def _sample_x( rng: Generator, prior: gp.Prior, posterior: gp.Posterior, n: int ) -> ndarray: Q = posterior.precision G = prior.precision_decomposed H = 1 / posterior.obs_noise * scipy.sparse.diags(posterior.obs_location_mask) d = Q.shape[0] z1 = rng.normal(size=(d, n)) z2 = rng.normal(size=(d, n)) y = G.T @ z1 + H.T @ z2 # reparametrization; sample from N(0, Q) return scipy.sparse.linalg.spsolve(Q, y) def _compute_marginal_variance_scipy( batch: list[int], posterior: gp.Posterior, x: spmatrix, N: int ) -> list[tuple[int, float]]: Q = posterior.precision results = [] for i in tqdm(batch): Qini = scipy.sparse.hstack([Q[i, :i], Q[i, i + 1 :]]) # Q(i, -i) xni = np.vstack((x[:i, :], x[i + 1 :, :])) # x(-i,:) Qini_xni = np.sum(Qini.T.multiply(xni), axis=0) result = 1 / Q[i, i] + np.sum(np.array(Qini_xni) ** 2, axis=1) / ( N * Q[i, i] ** 2 ) results.append((i, result.item())) return results if __name__ == "__main__": main()