# da-message-passing/src/experiments/mp_sphere.py
from pathlib import Path
from time import time

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import default_rng

from damp import gp, message_passing
from damp.gp import Shape
from damp.graph import graph_and_diagonal_from_precision_matrix
from damp.message_passing import Config


def main(
    *,
    ratio: int = 8,
    obs_fraction: float = 0.1,
    obs_noise: float = 1e-3,
    n_iterations: int = 1000,
) -> None:
    """Run message passing on the UM temperature field and plot the result.

    Steps:

    1. Load the UM temperature field and its lon/lat coordinates, plus the
       regridded ERA5 temperature.
    2. Subsample the fields and subtract ERA5 so the GP works on the anomaly.
    3. Condition the prior from ``gp.get_prior_sphere`` on randomly chosen
       noisy observations.
    4. Run ``message_passing.iterate`` on the graph built from the posterior
       precision.
    5. Save a three-panel figure (ground truth, message-passing mean,
       marginal std) to ``plots/mp_sphere.png`` next to this script.

    Args:
        ratio: Keep every ``ratio``-th grid point along each axis (and every
            ``ratio``-th lon/lat entry). Must be >= 1.
        obs_fraction: Fraction of grid points to observe.
        obs_noise: Observation noise passed to both
            ``gp.choose_observations`` and ``gp.get_posterior``.
        n_iterations: Number of message-passing iterations.

    Raises:
        ValueError: If ``ratio`` is smaller than 1.
    """
    if ratio < 1:
        raise ValueError(f"ratio must be >= 1, got {ratio}")

    # Resolve paths from this script's location rather than the current
    # working directory. When run from src/experiments this matches the
    # previous "../../data" and "plots" relative paths.
    script_dir = Path(__file__).resolve().parent
    data_dir = script_dir.parent.parent / "data"
    plots_dir = script_dir / "plots"

    numpy_rng = default_rng(seed=1124)

    # Subsample for testing.
    ground_truth = np.load(data_dir / "UM" / "UM_temp.npy")[::ratio, ::ratio]
    era5 = np.load(data_dir / "ERA5" / "temp_regrid.npy")[::ratio, ::ratio]
    lon = np.load(data_dir / "UM" / "UM_lon.npy")[::ratio]
    lat = np.load(data_dir / "UM" / "UM_lat.npy")[::ratio]

    # Zero-mean the truth using the ERA5 climatology. Keep the original
    # field separately so it can be plotted without a round-trip through
    # subtraction and re-addition (which adds floating-point error).
    anomaly = ground_truth - era5

    prior = gp.get_prior_sphere(
        Shape(anomaly.shape[0], anomaly.shape[1]), lon, lat
    )
    obs = gp.choose_observations(
        numpy_rng,
        n_obs=round(prior.shape.width * prior.shape.height * obs_fraction),
        ground_truth=anomaly,
        obs_noise=obs_noise,
    )

    posterior = gp.get_posterior(prior, obs, obs_noise)

    graph, Gamma_diagonal = graph_and_diagonal_from_precision_matrix(
        posterior.precision
    )
    config = Config(
        graph=graph,
        c=-2.0,
        Gamma_diagonal=jnp.array(Gamma_diagonal),
        h=jnp.array(posterior.shift),
        lr=0.7,
    )
    initial_edges = message_passing.get_initial_edges(graph)

    start = time()
    _, marginals = message_passing.iterate(
        config, initial_edges, n_iterations=n_iterations
    )
    # NOTE(review): if iterate returns JAX arrays, JAX dispatches work
    # asynchronously. Copying the results to host NumPy arrays inside the
    # timed window forces the computation to finish before `end` is read.
    mean_anomaly = np.asarray(marginals.mean.reshape(prior.interior_shape))
    std = np.asarray(marginals.std.reshape(prior.interior_shape))
    end = time()
    print(f"Took {end-start:.2f}s")

    # Convert back to absolute temperature. Assumes interior_shape drops a
    # one-cell border, matching the [1:-1, 1:-1] slice — TODO confirm.
    mean = mean_anomaly + era5[1:-1, 1:-1]

    vmin = ground_truth.min()
    vmax = ground_truth.max()
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
    axes[0].imshow(np.flipud(ground_truth), vmin=vmin, vmax=vmax)
    axes[0].set_title("Ground Truth")
    axes[1].imshow(np.flipud(mean), vmin=vmin, vmax=vmax)
    axes[1].set_title("Message Passing")
    axes[2].imshow(np.flipud(std))
    axes[2].set_title("Marginal Std")

    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()
    # savefig does not create missing directories.
    plots_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(plots_dir / "mp_sphere.png", dpi=300)
    plt.close(fig)


if __name__ == "__main__":
    main()