# da-message-passing / src / experiments / inla_demo.py
from pathlib import Path

import matplotlib.pyplot as plt
from numpy.random import default_rng

from damp import gp, ground_truth_cache, inla_bridge
from damp.gp import Shape


def main() -> None:
    """Run the INLA demo and save a three-panel comparison figure.

    Builds a GP prior with ``Shape(24, 16)`` and takes the first ground truth
    yielded by ``ground_truth_cache.load_or_gen`` for it. It then chooses 30
    observations (``obs_noise=1e-3``) with a fixed-seed RNG and passes the prior
    and observations to ``inla_bridge.run``. The figure is written to
    ``plots/inla_demo.png``, creating the ``plots`` directory if needed. Its
    panels are:

    1. the ground truth,
    2. the predicted means, on the same colour scale as the ground truth,
    3. the predicted standard deviations, with observation locations in red.
    """
    numpy_rng = default_rng(seed=1120)

    prior = gp.get_prior(Shape(24, 16))
    ground_truth = next(ground_truth_cache.load_or_gen(prior))
    obs = gp.choose_observations(
        numpy_rng, n_obs=30, ground_truth=ground_truth, obs_noise=1e-3
    )

    pred_means, pred_stds = inla_bridge.run(prior, obs)

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
    # Share the ground truth's colour range with the means panel so the two
    # images are directly comparable.
    vmin = ground_truth.min()
    vmax = ground_truth.max()
    # Transposing puts the first array axis on the horizontal, matching the
    # (x, y) order used by the scatter overlay below.
    axes[0].imshow(ground_truth.T, vmin=vmin, vmax=vmax)
    axes[0].set_title("Ground truth")
    axes[1].imshow(pred_means.T, vmin=vmin, vmax=vmax)
    axes[1].set_title("Predicted mean")
    axes[2].imshow(pred_stds.T)
    axes[2].set_title("Predicted std + observations")
    # NOTE(review): the `- 1` shift presumably converts 1-based observation
    # coordinates to 0-based pixel indices; confirm against
    # gp.choose_observations.
    obs_xs, obs_ys = zip(*[(x - 1, y - 1) for (x, y), _ in obs])
    axes[2].scatter(obs_xs, obs_ys, color="red", s=1)

    output_path = Path("plots/inla_demo.png")
    # savefig does not create missing directories, so make sure ours exists.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path)
    # Close this specific figure rather than whatever happens to be current.
    plt.close(fig)


if __name__ == "__main__":
    main()