# da-message-passing/src/experiments/mp_multigrid_sphere.py
import time
from pathlib import Path

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import default_rng

from damp import gp, ground_truth_cache, metrics, multigrid
from damp.gp import Shape


def main() -> None:
    """Run multigrid message passing on the UM temperature field on the sphere.

    Steps:

    1. Load the UM temperature field and its lon/lat grids.
    2. Subtract the regridded ERA5 field (treated as the climatology mean) so
       the target is zero-mean.
    3. Draw noisy observations (``round(prior.shape.flatten() * 0.05)`` of
       them) with a fixed-seed RNG.
    4. Run ``multigrid.run_on_sphere`` over grids coarsened by ``ratios``.
    5. Print the total runtime and the RMSE of the final posterior mean.
    6. Save one comparison figure per level into ``plots/multigrid/``. The
       directory is created if missing.

    All paths are relative to the current working directory. Given the
    ``../../data`` prefix, the script is presumably meant to be run from
    ``src/experiments`` — confirm.
    """
    numpy_rng = default_rng(seed=1124)
    data_dir = Path("../../data")
    ground_truth = np.load(data_dir / "UM" / "UM_temp.npy")
    lon = np.load(data_dir / "UM" / "UM_lon.npy")
    lat = np.load(data_dir / "UM" / "UM_lat.npy")
    era5 = np.load(data_dir / "ERA5" / "temp_regrid.npy")

    # Subsample for testing (ratio = 1 keeps full resolution).
    ratio = 1
    ground_truth = ground_truth[::ratio, ::ratio]
    era5 = era5[::ratio, ::ratio]
    lon = lon[::ratio]
    lat = lat[::ratio]
    # Zero mean the truth based on the climatology mean.
    ground_truth = ground_truth - era5

    width = np.shape(ground_truth)[0]
    length = np.shape(ground_truth)[1]

    prior = gp.get_prior_sphere(Shape(width, length), lon, lat)

    obs_noise = 1e-3
    obs = gp.choose_observations(
        numpy_rng,
        n_obs=round(prior.shape.flatten() * 0.05),
        ground_truth=ground_truth,
        obs_noise=obs_noise,
    )

    # One grid per ratio, coarse to fine. Exact integer (floor) division
    # avoids routing the size through a float.
    ratios = [4, 2, 1]
    levels = [Shape(width // r, length // r) for r in ratios]

    # perf_counter is monotonic and high-resolution, so it is the right clock
    # for elapsed time (time.time() can jump with system clock changes).
    # NOTE(review): JAX dispatches work asynchronously. If run_on_sphere
    # returns arrays that are still being computed, this may under-report the
    # runtime; consider jax.block_until_ready on the result — confirm.
    start_time = time.perf_counter()
    output = multigrid.run_on_sphere(
        prior,
        obs,
        obs_noise,
        lat=lat,
        lon=lon,
        ratios=ratios,
        levels=levels,
        c=-2.0,
        lr=0.7,
    )
    elapsed = time.perf_counter() - start_time
    print("Total Runtime = ", elapsed)

    # Each output entry unpacks as (level, iterations, marginals); the last
    # one is the finest level.
    # NOTE(review): the plots pad the mean by 1 cell per side, suggesting it
    # may be smaller than ground_truth; confirm metrics.rmse expects that.
    _, _, final_marginals = output[-1]
    print(f"RMSE = {metrics.rmse(final_marginals.mean, ground_truth).item()}")

    _plot_levels(output, ground_truth, Path("plots/multigrid"))


def _plot_levels(output, ground_truth, out_dir: Path) -> None:
    """Save a ground-truth vs. posterior-mean figure for every multigrid level.

    Args:
        output: sequence of ``(level, iterations, marginals)`` tuples as
            returned by ``multigrid.run_on_sphere``.
        ground_truth: the (zero-meaned) reference field.
        out_dir: directory for ``multigrid_<i>.png``; created if missing so
            the run does not crash at the very end on a fresh checkout.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # One shared colour scale so all figures are directly comparable;
    # computed once because it does not depend on the level.
    vmin = ground_truth.min()
    vmax = ground_truth.max()

    for i, (level, iterations, level_marginals) in enumerate(output):
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(4, 2))

        # flipud so the first row is drawn at the bottom of the image.
        axes[0].imshow(np.flipud(ground_truth), vmin=vmin, vmax=vmax)
        # Pad by one cell on each side. Presumably the marginals exclude the
        # boundary nodes — TODO confirm against multigrid.run_on_sphere.
        axes[1].imshow(
            np.flipud(jnp.pad(level_marginals.mean, 1)), vmin=vmin, vmax=vmax
        )
        axes[0].set_title("Ground Truth", fontsize=8)
        axes[1].set_title(
            f"Multigrid (level = {level}) \n {iterations} Iterations", fontsize=8
        )
        for ax in axes.flatten():
            ax.set_xticks([])
            ax.set_yticks([])

        fig.tight_layout()
        fig.savefig(out_dir / f"multigrid_{i}.png", dpi=300)
        # Close this specific figure to free memory across iterations.
        plt.close(fig)


if __name__ == "__main__":
    # Entry point: run the experiment only when executed as a script,
    # never as a side effect of importing this module.
    main()