"""Plot how message-passing edge messages evolve over iterations.

Builds a GP posterior from noisy observations of a ground truth obtained via
``ground_truth_cache``, runs message passing while saving snapshots of the
edge messages, and writes ``plots/message_dist.png``.  For both the ``a`` and
``b`` messages the figure shows a histogram over all saved snapshots, the
maximum absolute value per snapshot, and the mean absolute value per snapshot.
"""

from pathlib import Path

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import Array
from matplotlib.axes import Axes
from numpy.random import default_rng

from damp import gp, ground_truth_cache, message_passing
from damp.graph import graph_and_diagonal_from_precision_matrix
from damp.message_passing import Config


def main() -> None:
    """Run message passing with history and save message-statistics plots.

    The output directory ``plots/`` is created if it does not already exist,
    so the (long) message-passing run is not wasted on a failed save.
    """
    numpy_rng = default_rng(seed=1124)
    grid_size = 200
    # Side length used to reshape the flat posterior quantities below; it must
    # agree with prior.interior_shape (presumably the grid minus its border).
    interior_size = grid_size - 2

    prior = gp.get_prior(grid_size)
    ground_truth = next(ground_truth_cache.load_or_gen(prior, start_at=0))
    obs_noise = 1e-3
    obs = gp.choose_observations(
        numpy_rng, grid_size, n_obs=800, ground_truth=ground_truth, obs_noise=obs_noise
    )
    posterior = gp.get_posterior(prior, obs, obs_noise)

    graph, Gamma_diagonal = graph_and_diagonal_from_precision_matrix(
        posterior.precision, prior.interior_shape
    )
    config = Config(
        graph=graph,
        c=-2.0,
        Gamma_diagonal=jnp.array(Gamma_diagonal).reshape(interior_size, interior_size),
        h=jnp.array(posterior.shift).reshape(interior_size, interior_size),
        lr=0.7,
    )
    initial_edges = message_passing.get_initial_edges(graph, interior_size)

    history = message_passing.iterate_with_history(
        config,
        initial_edges,
        n_iterations=5000,
        parallel=True,
        save_every=100,
    )

    iterations = []
    a_snapshots = []
    b_snapshots = []
    # Each history entry is unpacked as (iteration, edges, <unused third item>).
    for iteration, edges, _ in history:
        iterations.append(iteration)
        a_edges = graph.get_all_edges(edges.a)
        a_snapshots.append(a_edges)
        # Debug output: flags NaNs in this snapshot's a-messages.
        print("is nan", jnp.any(jnp.isnan(a_edges)))
        b_snapshots.append(graph.get_all_edges(edges.b))

    # Rows: histogram / abs max over time / abs mean over time.
    # Columns: a messages / b messages.
    fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 7))
    for col, (title, snapshots) in enumerate((("a", a_snapshots), ("b", b_snapshots))):
        axes[0, col].set_title(title)
        _hist(jnp.concatenate(snapshots), axes[0, col])
        axes[1, col].plot(iterations, [jnp.max(jnp.abs(s)) for s in snapshots])
        axes[2, col].plot(iterations, [jnp.mean(jnp.abs(s)) for s in snapshots])
    axes[0, 0].set_ylabel("count")
    axes[1, 0].set_ylabel("abs max")
    axes[2, 0].set_ylabel("abs mean")

    fig.tight_layout()
    out_path = Path("plots") / "message_dist.png"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    # Close this specific figure so it does not linger in pyplot's registry.
    plt.close(fig)


def _hist(xs: Array, ax: Axes) -> None:
    """Draw a 50-bin histogram of ``xs`` on ``ax`` as a bar chart.

    ``jnp.histogram`` returns bin *edges*, so each bar is anchored at its
    bin's left edge (``align="edge"``).  With matplotlib's default
    ``align="center"`` every bar would be shifted left by half a bin.
    """
    counts, bin_edges = jnp.histogram(xs, bins=50)
    # bins=50 yields equally spaced edges, so one width fits every bar.
    ax.bar(bin_edges[:-1], counts, width=bin_edges[1] - bin_edges[0], align="edge")


if __name__ == "__main__":
    main()