from pathlib import Path
from time import time

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import default_rng

from damp import gp, message_passing
from damp.gp import Shape
from damp.graph import graph_and_diagonal_from_precision_matrix
from damp.message_passing import Config


def main() -> None:
    """Run message passing on the UM temperature field and plot the result.

    Steps, in order:

    1. Load the UM temperature, longitude and latitude arrays and the
       regridded ERA5 temperature array from ``.npy`` files (paths are
       relative to the current working directory).
    2. Subsample every array by a fixed stride and subtract ERA5 from the
       UM field.
    3. Build a prior with ``gp.get_prior_sphere``, draw noisy observations
       with ``gp.choose_observations`` and form the posterior with
       ``gp.get_posterior``.
    4. Convert the posterior precision into a graph and run
       ``message_passing.iterate`` for 1000 iterations, printing the
       elapsed wall-clock time.
    5. Save a three-panel figure (ground truth, message-passing mean,
       message-passing std) to ``plots/mp_sphere.png``.

    Raises:
        FileNotFoundError: If any of the input ``.npy`` files is missing
            (raised by ``np.load``).
    """
    numpy_rng = default_rng(seed=1124)

    ground_truth = np.load("../../data/UM/UM_temp.npy")
    lon = np.load("../../data/UM/UM_lon.npy")
    lat = np.load("../../data/UM/UM_lat.npy")
    era5 = np.load("../../data/ERA5/temp_regrid.npy")

    # Subsample for testing
    ratio = 8
    ground_truth = ground_truth[::ratio, ::ratio]
    era5 = era5[::ratio, ::ratio]
    lon = lon[::ratio]
    lat = lat[::ratio]

    # Zero mean the truth based on the climatology mean
    ground_truth = ground_truth - era5

    n_rows, n_cols = np.shape(ground_truth)[0], np.shape(ground_truth)[1]
    prior = gp.get_prior_sphere(Shape(n_rows, n_cols), lon, lat)

    obs_noise = 1e-3
    obs = gp.choose_observations(
        numpy_rng,
        # Observe 10% of the prior's width * height grid cells.
        n_obs=round(prior.shape.width * prior.shape.height * 0.1),
        ground_truth=ground_truth,
        obs_noise=obs_noise,
    )
    posterior = gp.get_posterior(prior, obs, obs_noise)

    graph, Gamma_diagonal = graph_and_diagonal_from_precision_matrix(
        posterior.precision
    )
    config = Config(
        graph=graph,
        c=-2.0,
        Gamma_diagonal=jnp.array(Gamma_diagonal),
        h=jnp.array(posterior.shift),
        lr=0.7,
    )
    initial_edges = message_passing.get_initial_edges(graph)

    # NOTE(review): if `iterate` returns asynchronously dispatched JAX
    # arrays, this interval may not include the full compute time -- confirm.
    start = time()
    _, marginals = message_passing.iterate(
        config, initial_edges, n_iterations=1000
    )
    end = time()
    print(f"Took {end-start:.2f}s")

    # Converting back
    ground_truth = ground_truth + era5
    # The mean is reshaped to `prior.interior_shape`; the ERA5 field is
    # trimmed by one cell on every side before being added back
    # (presumably the interior excludes a one-cell border -- TODO confirm).
    mean = (
        np.array(marginals.mean.reshape(prior.interior_shape))
        + era5[1:-1, 1:-1]
    )

    # Share one colour scale between the truth and the estimate so the two
    # panels are directly comparable.
    vmin = ground_truth.min()
    vmax = ground_truth.max()

    _, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))

    axes[0].imshow(np.flipud(ground_truth), vmin=vmin, vmax=vmax)
    axes[0].set_title("Ground Truth")

    axes[1].imshow(np.flipud(mean), vmin=vmin, vmax=vmax)
    axes[1].set_title("Message Passing")

    # Third panel: marginal standard deviation on its own colour scale.
    axes[2].imshow(np.flipud(marginals.std.reshape(prior.interior_shape)))

    for ax in axes.flatten():
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    # Create the output directory up front: savefig does not create missing
    # parent directories and would otherwise fail after the whole run.
    Path("plots").mkdir(parents=True, exist_ok=True)
    plt.savefig("plots/mp_sphere.png", dpi=300)
    plt.close()


if __name__ == "__main__":
    main()