# da-message-passing/src/experiments_archive/residuals.py
from typing import Sequence

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import Array
from matplotlib import cm
from matplotlib.pyplot import Axes
from numpy.random import default_rng

from damp import gp, ground_truth_cache, message_passing
from damp.graph import graph_and_diagonal_from_precision_matrix
from damp.message_passing import Config


def main() -> None:
    """Build and save the residuals figure for several observation counts.

    The figure has one row per entry in ``n_obses`` and two columns. ``run``
    draws each row: the left column from the node means, the right column
    from the node standard deviations. The figure is written to
    ``plots/residuals.png``; the ``plots`` directory is created if missing.
    """
    from pathlib import Path

    n_obses = [20, 200]
    fig, axes = plt.subplots(
        nrows=len(n_obses), ncols=2, figsize=(10, 5), squeeze=False
    )
    # Titles must match what ``run`` draws: means in column 0, stds in column 1.
    axes[0][0].set_title("mean")
    axes[0][1].set_title("std")
    for n_obs, ax_row in zip(n_obses, axes):
        run(ax_row, n_obs)
    fig.tight_layout()

    out_path = Path("plots/residuals.png")
    # savefig does not create missing directories.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    plt.close(fig)


def run(axs: Sequence[Axes], n_obs: int) -> None:
    """Run message passing on a GP posterior and plot its residuals on one row.

    Steps:

    * Build a GP prior with ``grid_size`` 50.
    * Take the first ground-truth sample from the cache.
    * Choose ``n_obs`` observations of it (``obs_noise=1e-3``, fixed seed
      1124) and form the posterior.
    * Derive the message-passing graph from the posterior precision and
      iterate for 500 iterations, saving every iteration.
    * Draw the node means on ``axs[0]`` and the node standard deviations on
      ``axs[1]`` with ``_plot_dist``.
    * Label ``axs[0]`` with the observation percentage.

    Args:
        axs: The two axes of one figure row to draw on.
        n_obs: Number of observations to condition the posterior on.
    """
    numpy_rng = default_rng(seed=1124)
    grid_size = 50
    # Message-passing arrays are (grid_size - 2)-square. Presumably this is the
    # grid without its one-cell border -- confirm against ``gp``.
    interior_size = grid_size - 2
    interior_shape = (interior_size, interior_size)

    prior = gp.get_prior(grid_size)
    ground_truth = next(ground_truth_cache.load_or_gen(prior))
    obs = gp.choose_observations(numpy_rng, n_obs, ground_truth, obs_noise=1e-3)
    posterior = gp.get_posterior(prior, obs)

    graph, Gamma_diagonal = graph_and_diagonal_from_precision_matrix(
        posterior.precision, prior.interior_shape
    )

    config = Config(
        graph=graph,
        c=-2.0,
        Gamma_diagonal=jnp.array(Gamma_diagonal).reshape(interior_shape),
        h=jnp.array(posterior.shift).reshape(interior_shape),
        lr=0.7,
    )
    initial_edges = message_passing.get_initial_edges(graph, interior_size)

    history = message_passing.iterate_with_history(
        config,
        initial_edges,
        n_iterations=500,
        parallel=True,
        save_every=1,
    )

    # Each history entry unpacks as a 3-tuple. Only the third element (which
    # exposes ``.mean`` and ``.std``) is used. Each snapshot is flattened so
    # that rows are snapshots and columns are nodes.
    n_snapshots = len(history)
    means = jnp.stack([nodes.mean for _, _, nodes in history])
    means = means.reshape(n_snapshots, -1)
    stds = jnp.stack([nodes.std for _, _, nodes in history])
    stds = stds.reshape(n_snapshots, -1)
    _plot_dist(axs[0], means)
    _plot_dist(axs[1], stds)

    obs_percent = n_obs / (grid_size * grid_size) * 100
    axs[0].set_ylabel(f"n obs = {obs_percent:.1f}%")


def _plot_dist(ax: Axes, xs: Array) -> None:
    """Plot the distribution of step-to-step absolute changes of ``xs``.

    ``xs`` has one row per saved snapshot and one column per node. For every
    pair of consecutive rows, the absolute change of each node is taken.
    Nodes whose final change is NaN are dropped.

    Per step, ten evenly spaced percentiles (0th to 100th) of the changes
    across nodes are drawn. The 0th percentile is a line; each band between
    consecutive percentiles is filled with a ``viridis`` colour. The x axis
    is the step index. The y range spans 0 to the largest finite change.

    Nothing is drawn if there are fewer than two snapshots or if every node
    is dropped.

    Args:
        ax: Axes to draw on.
        xs: 2-D array of shape ``(n_snapshots, n_nodes)``.
    """
    if xs.shape[0] < 2:
        # At least two snapshots are needed to form a step-to-step change.
        return
    # Difference consecutive snapshots (axis 0), not neighbouring nodes
    # (axis 1): a residual is how far each node moved in one step.
    diffs = jnp.abs(jnp.diff(xs, axis=0))
    # Keep only nodes whose last residual is a number.
    diffs = diffs[:, ~jnp.isnan(diffs[-1])]
    if diffs.shape[1] == 0:
        return

    qs = jnp.linspace(0.0, 100.0, 10)
    colors = cm.viridis(jnp.linspace(0.0, 1.0, len(qs)))
    # Shape (len(qs), n_steps): one curve per percentile level.
    percentiles = jnp.percentile(diffs, q=qs, axis=1)
    steps = list(range(percentiles.shape[1]))

    ax.plot(steps, percentiles[0], color=colors[0])
    for color, lower, upper in zip(colors[1:], percentiles[:-1], percentiles[1:]):
        ax.fill_between(steps, lower, upper, color=color)

    ax.set_xlim(left=0.0)
    # matplotlib rejects NaN/Inf axis limits, so size the y axis from finite
    # values only.
    finite = diffs[jnp.isfinite(diffs)]
    if finite.size:
        ax.set_ylim(bottom=0.0, top=float(finite.max()))
    else:
        ax.set_ylim(bottom=0.0)


if __name__ == "__main__":
    # Script entry point: build and save the figure. Importing this module
    # does not trigger it.
    main()