# Source: da-message-passing / src / experiments_archive / message_dist.py
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import Array
from matplotlib.axes import Axes
from numpy.random import default_rng

from damp import gp, ground_truth_cache, message_passing
from damp.graph import graph_and_diagonal_from_precision_matrix
from damp.message_passing import Config


def main() -> None:
    """Run message passing on a GP posterior and plot statistics of the messages.

    The script builds a prior via ``gp.get_prior``, draws 800 noisy
    observations of the first cached ground truth, forms the posterior and
    converts its precision matrix into a graph. It then iterates
    ``message_passing.iterate_with_history`` (5000 iterations, ``save_every=100``)
    and collects the ``a`` and ``b`` edge values of every yielded entry.

    The resulting 3x2 figure (left column ``a``, right column ``b``) shows:

    * row 0: histogram of all collected values,
    * row 1: maximum absolute value per history entry,
    * row 2: mean absolute value per history entry.

    The figure is written to ``plots/message_dist.png``.
    """
    # Local import: only needed to make sure the output directory exists.
    from pathlib import Path

    numpy_rng = default_rng(seed=1124)
    grid_size = 200
    # NOTE(review): presumably the grid has a one-cell boundary on each side,
    # hence the "- 2" — confirm against ``gp.get_prior``.
    interior_size = grid_size - 2

    prior = gp.get_prior(grid_size)
    ground_truth = next(ground_truth_cache.load_or_gen(prior, start_at=0))
    obs_noise = 1e-3
    obs = gp.choose_observations(
        numpy_rng, grid_size, n_obs=800, ground_truth=ground_truth, obs_noise=obs_noise
    )
    posterior = gp.get_posterior(prior, obs, obs_noise)

    graph, Gamma_diagonal = graph_and_diagonal_from_precision_matrix(
        posterior.precision, prior.interior_shape
    )

    config = Config(
        graph=graph,
        c=-2.0,
        Gamma_diagonal=jnp.array(Gamma_diagonal).reshape(interior_size, interior_size),
        h=jnp.array(posterior.shift).reshape(interior_size, interior_size),
        lr=0.7,
    )
    initial_edges = message_passing.get_initial_edges(graph, interior_size)

    history = message_passing.iterate_with_history(
        config,
        initial_edges,
        n_iterations=5000,
        parallel=True,
        save_every=100,
    )

    # Create exactly one figure; the trailing ``plt.close()`` only closes the
    # current figure, so any extra figure created earlier would stay open.
    _, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 7))

    iterations = []
    a_values = []
    b_values = []
    # The third element of each history entry is not used by this plot.
    for iteration, edges, _ in history:
        iterations.append(iteration)
        a_values.append(graph.get_all_edges(edges.a))
        print("is nan", jnp.any(jnp.isnan(a_values[-1])))
        b_values.append(graph.get_all_edges(edges.b))

    a_abs_max = [jnp.max(jnp.abs(a)) for a in a_values]
    b_abs_max = [jnp.max(jnp.abs(b)) for b in b_values]
    a_abs_mean = [jnp.mean(jnp.abs(a)) for a in a_values]
    b_abs_mean = [jnp.mean(jnp.abs(b)) for b in b_values]

    axes[0, 0].set_title("a")
    axes[0, 1].set_title("b")

    axes[0, 0].set_ylabel("count")
    _hist(jnp.concatenate(a_values), axes[0, 0])
    _hist(jnp.concatenate(b_values), axes[0, 1])

    axes[1, 0].set_ylabel("abs max")
    axes[1, 0].plot(iterations, a_abs_max)
    axes[1, 1].plot(iterations, b_abs_max)

    axes[2, 0].set_ylabel("abs mean")
    axes[2, 0].plot(iterations, a_abs_mean)
    axes[2, 1].plot(iterations, b_abs_mean)

    plt.tight_layout()
    # ``savefig`` does not create missing directories, so make sure it exists.
    Path("plots").mkdir(parents=True, exist_ok=True)
    plt.savefig("plots/message_dist.png")
    plt.close()


def _hist(xs: Array, ax: Axes) -> None:
    """Draw a 50-bin histogram of ``xs`` as a bar chart on ``ax``.

    Args:
        xs: Values to bin; passed unchanged to ``jnp.histogram``.
        ax: Axes on which the bars are drawn.
    """
    hist, bin_edges = jnp.histogram(xs, bins=50)
    # ``bin_edges[:-1]`` holds the *left* edge of every bin, so anchor each bar
    # at its left edge. With the default ``align="center"`` each bar would be
    # centred on its left edge, shifting the whole histogram by half a bin.
    ax.bar(bin_edges[:-1], hist, width=bin_edges[1] - bin_edges[0], align="edge")


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()