import time
from pathlib import Path

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import default_rng

from damp import gp, ground_truth_cache, metrics, multigrid
from damp.gp import Shape


def main() -> None:
    """Run the multigrid GP solver on the UM temperature field and plot each level.

    Steps:
      1. Load the UM temperature field with its lon/lat coordinates, and the
         ERA5 field.
      2. Subtract ERA5 from the truth (used as the climatology mean).
      3. Build a spherical GP prior.
      4. Draw noisy observations for ~5% of the grid cells.
      5. Call ``multigrid.run_on_sphere`` over grids of shape
         ``(width // r, length // r)`` for each ratio ``r``.
      6. Print the total runtime and the RMSE of the final level's mean
         against the truth.
      7. Save one side-by-side figure per level to ``plots/multigrid/``.

    Data is read from relative paths under ``../../data``, so the script
    must be run from the directory those paths are relative to.
    """
    numpy_rng = default_rng(seed=1124)

    ground_truth = np.load("../../data/UM/UM_temp.npy")
    lon = np.load("../../data/UM/UM_lon.npy")
    lat = np.load("../../data/UM/UM_lat.npy")
    era5 = np.load("../../data/ERA5/temp_regrid.npy")

    # Subsample for testing (ratio = 1 keeps the full grid).
    ratio = 1
    ground_truth = ground_truth[::ratio, ::ratio]
    era5 = era5[::ratio, ::ratio]
    lon = lon[::ratio]
    lat = lat[::ratio]

    # Zero mean the truth based on the climatology mean
    ground_truth = ground_truth - era5

    width, length = np.shape(ground_truth)[:2]
    prior = gp.get_prior_sphere(Shape(width, length), lon, lat)

    obs_noise = 1e-3
    obs = gp.choose_observations(
        numpy_rng,
        # Observe roughly 5% of the grid cells.
        n_obs=round(prior.shape.flatten() * 0.05),
        ground_truth=ground_truth,
        obs_noise=obs_noise,
    )

    # Set up levels: one grid per downsampling factor.
    ratios = [4, 2, 1]
    levels = [Shape(width // factor, length // factor) for factor in ratios]

    # perf_counter is monotonic and intended for measuring elapsed time.
    # NOTE(review): JAX dispatches work asynchronously; if run_on_sphere does
    # not force results to the host, this may under-report the runtime.
    start_time = time.perf_counter()
    output = multigrid.run_on_sphere(
        prior,
        obs,
        obs_noise,
        lat=lat,
        lon=lon,
        ratios=ratios,
        levels=levels,
        c=-2.0,
        lr=0.7,
    )
    end_time = time.perf_counter()
    print("Total Runtime = ", (end_time - start_time))

    _, _, final_marginals = output[-1]
    print(f"RMSE = {metrics.rmse(final_marginals.mean, ground_truth).item()}")

    # Share one colour scale across every figure so levels are comparable.
    vmin = ground_truth.min()
    vmax = ground_truth.max()

    # savefig does not create missing directories; make sure it exists.
    plot_dir = Path("plots/multigrid")
    plot_dir.mkdir(parents=True, exist_ok=True)

    for i, (level, iterations, level_marginals) in enumerate(output):
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(4, 2))

        axes[0].imshow(np.flipud(ground_truth), vmin=vmin, vmax=vmax)
        # The mean is padded by one cell on every side before display.
        # NOTE(review): presumably the marginals omit the boundary ring — confirm.
        axes[1].imshow(
            np.flipud(jnp.pad(level_marginals.mean, 1)), vmin=vmin, vmax=vmax
        )

        axes[0].set_title("Ground Truth", fontsize=8)
        axes[1].set_title(
            f"Multigrid (level = {level}) \n {iterations} Iterations", fontsize=8
        )

        for ax in axes.flatten():
            ax.set_xticks([])
            ax.set_yticks([])

        fig.tight_layout()
        fig.savefig(plot_dir / f"multigrid_{i}.png", dpi=300)
        # Close this specific figure to free its memory inside the loop.
        plt.close(fig)


if __name__ == "__main__":
    main()