"""Plot how quickly GP message passing converges ("residuals").

For a few observation counts, run message passing on the GP posterior and show,
per saved iteration, percentile bands of the absolute change in each node's
mean and standard deviation since the previous saved iteration.
"""

from pathlib import Path
from typing import Sequence

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import Array
from matplotlib import cm
from matplotlib.pyplot import Axes
from numpy.random import default_rng

from damp import gp, ground_truth_cache, message_passing
from damp.graph import graph_and_diagonal_from_precision_matrix
from damp.message_passing import Config


def main() -> None:
    """Build the residuals figure (one row per observation count) and save it."""
    n_obses = [20, 200]
    fig, axes = plt.subplots(
        nrows=len(n_obses), ncols=2, figsize=(10, 5), squeeze=False
    )
    # Column titles must match what ``run`` draws: node means left, stds right.
    axes[0][0].set_title("mean")
    axes[0][1].set_title("std")

    for n_obs, ax_row in zip(n_obses, axes):
        run(ax_row, n_obs)

    fig.tight_layout()
    out_path = Path("plots/residuals.png")
    # savefig does not create missing directories.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    plt.close(fig)


def run(axs: Sequence[Axes], n_obs: int) -> None:
    """Run message passing on one GP posterior and plot its convergence.

    A prior is built on a ``grid_size`` x ``grid_size`` grid, ``n_obs`` noisy
    observations of a cached ground-truth sample are drawn, and message passing
    is iterated on the resulting posterior with every iteration saved. The
    per-iteration change of the node means is drawn on ``axs[0]`` and of the
    node standard deviations on ``axs[1]``.

    Args:
        axs: At least two axes (one row of the figure grid).
        n_obs: Number of observations to draw from the ground truth.
    """
    numpy_rng = default_rng(seed=1124)
    grid_size = 50
    # Message-passing arrays live on the interior grid (presumably the full grid
    # minus a one-cell border; see ``prior.interior_shape``).
    interior_size = grid_size - 2
    interior_shape = (interior_size, interior_size)

    prior = gp.get_prior(grid_size)
    ground_truth = next(ground_truth_cache.load_or_gen(prior))
    obs = gp.choose_observations(numpy_rng, n_obs, ground_truth, obs_noise=1e-3)
    posterior = gp.get_posterior(prior, obs)

    graph, Gamma_diagonal = graph_and_diagonal_from_precision_matrix(
        posterior.precision, prior.interior_shape
    )
    config = Config(
        graph=graph,
        c=-2.0,
        Gamma_diagonal=jnp.array(Gamma_diagonal).reshape(interior_shape),
        h=jnp.array(posterior.shift).reshape(interior_shape),
        lr=0.7,
    )
    initial_edges = message_passing.get_initial_edges(graph, interior_size)
    history = message_passing.iterate_with_history(
        config,
        initial_edges,
        n_iterations=500,
        parallel=True,
        save_every=1,
    )

    # Each history entry unpacks as a 3-tuple; the third element exposes the
    # node ``.mean`` / ``.std`` (the second exposes ``.shift`` / ``.precision``
    # if edge messages are ever wanted instead). Rows = saved iterations.
    n_saved = len(history)
    means = jnp.stack([nodes.mean for _, _, nodes in history]).reshape(n_saved, -1)
    stds = jnp.stack([nodes.std for _, _, nodes in history]).reshape(n_saved, -1)
    _plot_dist(axs[0], means)
    _plot_dist(axs[1], stds)

    obs_percent = n_obs / (grid_size * grid_size) * 100
    axs[0].set_ylabel(f"n obs = {obs_percent:.1f}%")


def _plot_dist(ax: Axes, xs: Array) -> None:
    """Draw percentile bands of per-iteration absolute updates on ``ax``.

    Args:
        ax: Axes to draw on.
        xs: 2-D array with one row per saved iteration and one column per
            tracked value (as built by ``run``).

    Nothing is drawn if fewer than two iterations are available or every
    column's final update is NaN.
    """
    if xs.shape[0] < 2:
        return

    # Change of every value between consecutive iterations: axis 0 is time.
    diffs = jnp.diff(xs, axis=0)
    # Keep only values whose final update is finite-valued (not NaN).
    diffs = jnp.abs(diffs[:, ~jnp.isnan(diffs[-1])])
    if diffs.size == 0:
        return

    qs = jnp.linspace(0.0, 100.0, 10)
    colors = cm.viridis(jnp.linspace(0.0, 1.0, len(qs)))
    # Shape (len(qs), n_steps): percentile across values at each step.
    percentiles = jnp.percentile(diffs, q=qs, axis=1)
    steps = list(range(percentiles.shape[1]))

    # Lowest percentile as a line, then a filled band between each
    # successive pair of percentiles.
    ax.plot(steps, percentiles[0], color=colors[0])
    for lower, upper, color in zip(percentiles[:-1], percentiles[1:], colors[1:]):
        ax.fill_between(steps, lower, upper, color=color)

    ax.set_xlim(left=0.0)
    # nanmax: a NaN upper limit would make matplotlib raise.
    ax.set_ylim(bottom=0.0, top=float(jnp.nanmax(diffs)))


if __name__ == "__main__":
    main()