# da-message-passing/src/experiments/inla_sphere.py
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from numpy.random import default_rng

from damp import gp, inla_bridge
from damp.gp import Shape


def _load_subsampled_fields(ratio: int):
    """Load the UM truth, ERA5 field and coordinates, strided by ``ratio``.

    Paths are relative to the current working directory (two levels up into
    ``data/``), so the script is expected to be launched from its own
    directory.

    Returns:
        ``(ground_truth, era5, lon, lat)``. The two fields are strided by
        ``ratio`` along their first two axes. The coordinate arrays are
        strided by ``ratio`` along their first axis.
    """
    ground_truth = np.load("../../data/UM/UM_temp.npy")[::ratio, ::ratio]
    era5 = np.load("../../data/ERA5/temp_regrid.npy")[::ratio, ::ratio]
    lon = np.load("../../data/UM/UM_lon.npy")[::ratio]
    lat = np.load("../../data/UM/UM_lat.npy")[::ratio]
    return ground_truth, era5, lon, lat


def _plot_comparison(ground_truth, pred_means, pred_stds, out_path: str) -> None:
    """Save a 1x3 panel of truth, predicted means and predicted stds.

    The truth and mean panels share the truth's colour range so they are
    visually comparable. The parent directory of ``out_path`` is created if
    missing. The figure is always closed, even if saving fails.
    """
    out_file = Path(out_path)
    out_file.parent.mkdir(parents=True, exist_ok=True)

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
    try:
        vmin = ground_truth.min()
        vmax = ground_truth.max()
        # flipud: rows are drawn top-down by imshow, so flip to put row 0 at
        # the bottom of the panel.
        axes[0].imshow(np.flipud(ground_truth), vmin=vmin, vmax=vmax)
        axes[1].imshow(np.flipud(pred_means), vmin=vmin, vmax=vmax)
        axes[2].imshow(np.flipud(pred_stds))
        fig.tight_layout()
        fig.savefig(out_file, dpi=300)
    finally:
        plt.close(fig)


def main(
    *,
    ratio: int = 8,
    seed: int = 1120,
    obs_fraction: float = 0.1,
    obs_noise: float = 1e-3,
    out_path: str = "plots/inla_sphere_era5.png",
) -> None:
    """Run the INLA-on-the-sphere experiment and save a comparison plot.

    Steps:
      1. Load the UM temperature field and the regridded ERA5 field.
      2. Subsample both by ``ratio``.
      3. Fit a sphere GP prior (via INLA) to the UM-minus-ERA5 anomaly, using
         randomly chosen noisy observations.
      4. Add ERA5 back to the predicted means and plot them against the truth.

    Args:
        ratio: Stride used to subsample the grid and coordinates. Must be >= 1.
        seed: Seed for the observation-selection RNG.
        obs_fraction: Fraction of grid points to observe.
        obs_noise: Observation noise passed to ``gp.choose_observations``.
        out_path: Where to write the figure; parent dirs are created.

    Raises:
        ValueError: If ``ratio`` < 1.
    """
    if ratio < 1:
        raise ValueError(f"ratio must be a positive integer, got {ratio}")

    numpy_rng = default_rng(seed=seed)
    ground_truth, era5, lon, lat = _load_subsampled_fields(ratio)

    # Zero-mean the truth using ERA5 as the climatology mean. Only this
    # anomaly is passed to the model. The original field is kept for plotting,
    # so there is no need to reconstruct it later.
    anomaly = ground_truth - era5

    prior = gp.get_prior_sphere(Shape(anomaly.shape[0], anomaly.shape[1]), lon, lat)
    obs = gp.choose_observations(
        numpy_rng,
        n_obs=round(prior.shape.width * prior.shape.height * obs_fraction),
        ground_truth=anomaly,
        obs_noise=obs_noise,
    )
    pred_means, pred_stds = inla_bridge.run(prior, obs)

    # Convert predictions back to absolute values. The [1:-1, 1:-1] slice
    # implies the predictions cover only the grid interior.
    # NOTE(review): confirm inla_bridge.run drops exactly one boundary cell per side.
    pred_means = pred_means + era5[1:-1, 1:-1]

    _plot_comparison(ground_truth, pred_means, pred_stds, out_path)


# Entry point when executed as a script. main() returns None, so the process
# exits with status 0 on success; importing this module runs nothing.
if __name__ == "__main__":
    raise SystemExit(main())