# da-message-passing / src / experiments / threedvar_sphere.py
from time import time

import jax.random
import matplotlib.pyplot as plt
from jax.random import PRNGKey
from numpy.random import default_rng
import numpy as np

import damp.threedvar as threedvar
from damp.gp import Shape
from damp import gp


def main() -> None:
    """Run the spherical 3D-Var experiment and save a comparison plot.

    Loads the UM temperature field with its longitude/latitude arrays and the
    regridded ERA5 temperature field from ``../../data`` (resolved against the
    current working directory), subsamples them, and subtracts ERA5 from the
    UM field. A prior is built with ``gp.get_prior_sphere``, observations are
    selected with ``gp.choose_observations``, and ``threedvar.run_optimizer``
    is run and timed. ERA5 is then added back, and the ground truth and the
    optimizer result are drawn side by side on a shared colour scale and
    written to ``plots/3dvar_sphere_era5.png``.

    The ``plots`` directory is created when missing so that the figure is not
    lost to a ``FileNotFoundError`` after the optimizer has already run.
    """
    # Local import: keeps this function self-contained without touching the
    # module-level import block.
    from pathlib import Path

    numpy_rng = default_rng(seed=1124)

    ground_truth = np.load("../../data/UM/UM_temp.npy")
    lon = np.load("../../data/UM/UM_lon.npy")
    lat = np.load("../../data/UM/UM_lat.npy")

    era5 = np.load("../../data/ERA5/temp_regrid.npy")

    # Subsample a grid for testing: keep every `ratio`-th entry along both
    # axes of the fields and along the coordinate arrays.
    ratio = 8
    ground_truth = ground_truth[::ratio, ::ratio]
    era5 = era5[::ratio, ::ratio]
    lon = lon[::ratio]
    lat = lat[::ratio]

    obs_noise = 1e-3
    # Zero mean the truth based on the climatology mean
    ground_truth = ground_truth - era5

    prior = gp.get_prior_sphere(
        Shape(np.shape(ground_truth)[0], np.shape(ground_truth)[1]), lon, lat
    )

    # Observe 10% of the prior's grid cells.
    obs = gp.choose_observations(
        numpy_rng,
        n_obs=round(prior.shape.width * prior.shape.height * 0.1),
        ground_truth=ground_truth,
        obs_noise=obs_noise,
    )

    rng = PRNGKey(seed=23142834)
    # Only the second key of the split is consumed; the first is discarded.
    _, rng_input = jax.random.split(rng)

    # NOTE(review): if run_optimizer returns a not-yet-materialised JAX array,
    # this interval may not cover the full computation -- confirm.
    start = time()
    result = threedvar.run_optimizer(rng_input, prior, obs, obs_noise)
    end = time()
    print(f"Took {end-start:.2f}s")

    # Converting back
    ground_truth = ground_truth + era5
    # NOTE(review): ERA5 is trimmed by one cell on every edge here, so the
    # optimizer presumably returns interior grid points only -- confirm
    # against damp.threedvar.
    result = result + era5[1:-1, 1:-1]

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    # Both panels share the ground truth's value range so colours compare.
    vmin = ground_truth.min()
    vmax = ground_truth.max()
    axes[0].imshow(np.flipud(ground_truth), vmin=vmin, vmax=vmax)
    axes[1].imshow(np.flipud(result), vmin=vmin, vmax=vmax)
    plt.tight_layout()

    output_path = Path("plots/3dvar_sphere_era5.png")
    # Ensure the target directory exists before writing the figure.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=300)
    plt.close(fig)


# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()