from collections import defaultdict
from collections.abc import Callable, Sequence
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import Array, jit
from numpy import ndarray

from damp import gp, message_passing
from damp.gp import Obs, Prior, Shape
from damp.graph import graph_and_diagonal_from_precision_matrix
from damp.message_passing import Config, Edges, Marginals

ObsMatrix = ndarray

# Iteration budget for any level not listed in a caller's `iterations` dict.
DEFAULT_ITERATIONS: dict[Shape, int] = defaultdict(lambda: 10_000)

# Early-stopping threshold passed to `message_passing.iterate` at every level.
_EARLY_STOPPING_THRESHOLD = 0.0001


def run(
    prior: Prior,
    obs: Obs,
    obs_noise: float,
    min_grid_size: int = 32,
    iterations: dict[Shape, int] = DEFAULT_ITERATIONS,
    **config_kwargs,
) -> list[tuple[Shape, int, Marginals]]:
    """Run multigrid message passing from a coarse grid up to ``prior.shape``.

    The levels come from :func:`build_levels`: they start from a grid whose
    largest dimension is ``min_grid_size`` and double in size up to the prior's
    shape. Every level except the last gets its prior from ``gp.get_prior``; the
    last level uses ``prior`` itself. The edges found on each level seed the
    next, finer level.

    Args:
        prior: Prior on the target (finest) grid.
        obs: Observations on the target grid, iterated as ``((x, y), value)``.
        obs_noise: Observation noise passed to ``gp.get_posterior``.
        min_grid_size: Largest dimension of the coarsest level.
        iterations: Per-level iteration budgets keyed by level shape. Levels not
            listed fall back to ``DEFAULT_ITERATIONS`` (10_000).
        **config_kwargs: Extra keyword arguments forwarded to ``Config``.

    Returns:
        One ``(level_shape, n_iterations, marginals)`` tuple per level, ordered
        coarse to fine. Each level's marginals are reshaped to that level's
        ``interior_shape``.
    """
    levels = build_levels(min_grid_size, prior.shape)

    def coarse_prior(level_i: int, level_shape: Shape) -> Prior:
        # TODO: Ensure the level prior has the same parameters as the target prior.
        # At the moment this doesn't matter because the only parameter is the grid
        # shape.
        return gp.get_prior(level_shape)

    return _run_levels(
        prior, obs, obs_noise, levels, coarse_prior, iterations, config_kwargs
    )


def run_on_sphere(
    prior: Prior,
    obs: Obs,
    obs_noise: float,
    lat: ndarray,
    lon: ndarray,
    ratios: Sequence[int] = (),
    levels: Sequence[Shape] = (),
    iterations: dict[Shape, int] = DEFAULT_ITERATIONS,
    **config_kwargs,
) -> list[tuple[Shape, int, Marginals]]:
    """Run multigrid message passing over caller-supplied levels on a sphere.

    Unlike :func:`run`, the levels are not derived here. Coarse level ``i``
    (every level except the last) uses
    ``gp.get_prior_sphere(levels[i], lon[::ratios[i]], lat[::ratios[i]])``.
    The last level always uses ``prior``.
    NOTE(review): ``levels`` is presumably ordered coarse to fine, with the last
    entry equal to ``prior.shape`` as in ``run``. Confirm with callers.

    Args:
        prior: Prior on the target (finest) grid.
        obs: Observations on the target grid, iterated as ``((x, y), value)``.
        obs_noise: Observation noise passed to ``gp.get_posterior``.
        lat: Latitudes, subsampled with step ``ratios[i]`` for coarse level ``i``.
        lon: Longitudes, subsampled with step ``ratios[i]`` for coarse level ``i``.
        ratios: Subsampling step per coarse level. It is indexed only for
            levels before the last.
        levels: Level shapes to run. If empty, the function returns ``[]``.
        iterations: Per-level iteration budgets keyed by level shape. Levels not
            listed fall back to ``DEFAULT_ITERATIONS`` (10_000).
        **config_kwargs: Extra keyword arguments forwarded to ``Config``.

    Returns:
        One ``(level_shape, n_iterations, marginals)`` tuple per level, in the
        order of ``levels``.
    """

    def coarse_prior(level_i: int, level_shape: Shape) -> Prior:
        # TODO: Ensure the level prior has the same parameters as the target prior.
        # At the moment this doesn't matter because the only parameter is the grid
        # shape.
        step = ratios[level_i]
        return gp.get_prior_sphere(level_shape, lon[::step], lat[::step])

    return _run_levels(
        prior, obs, obs_noise, levels, coarse_prior, iterations, config_kwargs
    )


def _run_levels(
    prior: Prior,
    obs: Obs,
    obs_noise: float,
    levels: Sequence[Shape],
    coarse_prior: Callable[[int, Shape], Prior],
    iterations: dict[Shape, int],
    config_kwargs: dict,
) -> list[tuple[Shape, int, Marginals]]:
    """Run the multigrid loop shared by :func:`run` and :func:`run_on_sphere`.

    ``coarse_prior(level_i, level_shape)`` builds the prior for every level
    except the last; the last level uses ``prior``.
    """
    target_shape = prior.shape
    # Merge into a fresh defaultdict so that unspecified levels get the default
    # budget and the caller's dict is never mutated.
    iterations = DEFAULT_ITERATIONS | iterations
    obs_grid = fill_obs_matrix(obs, target_shape)

    results: list[tuple[Shape, int, Marginals]] = []
    edges = None
    last_level = len(levels) - 1
    for level_i, level_shape in enumerate(levels):
        print(f"Running Message Passing for level {level_i}")
        if level_i == last_level:
            level_prior = prior
        else:
            level_prior = coarse_prior(level_i, level_shape)

        level_obs = pull_obs_from_target(obs_grid, target_shape, level_shape)
        posterior = gp.get_posterior(level_prior, level_obs, obs_noise)
        graph, Gamma_diagonal = graph_and_diagonal_from_precision_matrix(
            posterior.precision
        )
        config = Config(
            graph=graph,
            Gamma_diagonal=jnp.array(Gamma_diagonal),
            h=jnp.array(posterior.shift),
            **config_kwargs,
        )

        if level_i == 0:
            initial_edges = message_passing.get_initial_edges(graph)
        else:
            # Seed this level with the previous level's converged edges. The
            # previous level's edge grid is taken to be its grid shape minus 2
            # in each dimension.
            previous = levels[level_i - 1]
            initial_edges = _expand_edges(
                edges,
                Shape(previous.width - 2, previous.height - 2),
                level_prior.interior_shape,
            )

        n_iterations = iterations[level_shape]
        edges, level_marginals = message_passing.iterate(
            config,
            initial_edges,
            n_iterations=n_iterations,
            early_stopping_threshold=_EARLY_STOPPING_THRESHOLD,
        )
        interior_shape = level_prior.interior_shape
        marginals = Marginals(
            mean=level_marginals.mean.reshape(interior_shape),
            std=level_marginals.std.reshape(interior_shape),
        )
        results.append((level_shape, n_iterations, marginals))
    return results


def _expand_edges(edges: Edges, previous_shape: Shape, next_shape: Shape) -> Edges:
    """Gets the initial edges for the next multigrid level.

    ``edges.a`` and ``edges.b`` are reshaped to ``(previous_shape.width,
    previous_shape.height, k)``, upsampled by an integer ratio per axis, and
    padded with one cell on each side. An axis of length ``p`` therefore ends
    up with length ``ratio * p + 2``.

    The ratio is chosen as ``(next - 2) // p`` so that the output matches
    ``next_shape`` whenever ``next == ratio * p + 2``. The previous formula,
    ``next // p``, over-counts when ``p <= 2``; for example, ``p = 2`` and
    ``next = 6`` gave 3 instead of 2.
    """
    return _expand_edges_jittable(
        a_grid=edges.a.reshape(
            previous_shape.width, previous_shape.height, edges.a.shape[1]
        ),
        b_grid=edges.b.reshape(
            previous_shape.width, previous_shape.height, edges.b.shape[1]
        ),
        width_ratio=(next_shape.width - 2) // previous_shape.width,
        height_ratio=(next_shape.height - 2) // previous_shape.height,
    )


def _upsample_edge_grid(grid: Array, width_ratio: int, height_ratio: int) -> Array:
    """Upsample a ``(w, h, k)`` edge grid and flatten it to ``(rows, k)``.

    The steps are:
    - scale the values by ``width_ratio * height_ratio``;
    - repeat each cell ``width_ratio`` times along axis 0 and ``height_ratio``
      times along axis 1;
    - zero-pad one cell on each side of both spatial axes;
    - apply ``jnp.nan_to_num``, which maps NaN to 0 and ±inf to large finite
      values.
    """
    upsampled = width_ratio * height_ratio * grid
    upsampled = jnp.repeat(upsampled, width_ratio, axis=0)
    upsampled = jnp.repeat(upsampled, height_ratio, axis=1)
    upsampled = jnp.pad(upsampled, ((1, 1), (1, 1), (0, 0)))
    upsampled = jnp.nan_to_num(upsampled)
    return upsampled.reshape(-1, grid.shape[2])


@partial(jit, static_argnames=("width_ratio", "height_ratio"))
def _expand_edges_jittable(
    a_grid: Array, b_grid: Array, width_ratio: int, height_ratio: int
) -> Edges:
    """Apply :func:`_upsample_edge_grid` to both edge grids.

    The ratios are static arguments because they determine output shapes.
    """
    return Edges(
        a=_upsample_edge_grid(a_grid, width_ratio, height_ratio),
        b=_upsample_edge_grid(b_grid, width_ratio, height_ratio),
    )


def build_levels(min_size: int, target_shape: Shape) -> list[Shape]:
    """Return grid shapes from coarsest to ``target_shape``, doubling each step.

    The largest dimension of the coarsest level is ``min_size``. The other
    dimension is scaled proportionally and rounded, so it may be smaller.

    Raises:
        AssertionError: If the shape or ``min_size`` fail the basic parity and
            divisibility checks.
        ValueError: If repeated doubling of the coarsest shape never equals
            ``target_shape``. This happens when the target is not ``min_size``
            times a power of two, e.g. 96 with ``min_size=32``. Previously this
            case looped forever.
    """
    assert target_shape.width % 2 == 0 and target_shape.height % 2 == 0
    assert min_size >= 2 and min_size % 2 == 0
    assert target_shape.width % min_size == 0 and target_shape.height % min_size == 0

    # We reduce the largest dimension to the minimum size, thus the other dimension
    # might end up smaller.
    if target_shape.width >= target_shape.height:
        min_shape = Shape(
            width=min_size,
            height=round(target_shape.height / target_shape.width * min_size),
        )
    else:
        min_shape = Shape(
            width=round(target_shape.width / target_shape.height * min_size),
            height=min_size,
        )

    levels = [min_shape]
    # Stop as soon as either dimension reaches the target. The largest
    # dimension starts at min_size >= 2 and strictly grows, so this terminates.
    while (
        levels[-1].width < target_shape.width
        and levels[-1].height < target_shape.height
    ):
        levels.append(Shape(levels[-1].width * 2, levels[-1].height * 2))
    if levels[-1] != target_shape:
        raise ValueError(
            f"Cannot reach target shape {target_shape} by repeatedly doubling "
            f"{min_shape}; each target dimension must be the coarsest dimension "
            f"times a power of two."
        )
    return levels


def fill_obs_matrix(obs: Obs, target_shape: Shape) -> ObsMatrix:
    """
    Fill a matrix with the observation values and locations.

    Each ``(idx, val)`` in ``obs`` writes ``val`` at ``[idx[0], idx[1]]``.
    Cells without an observation stay 0, and a repeated location keeps the
    last value written.
    """
    obs_matrix = np.zeros(target_shape)
    for idx, val in obs:
        obs_matrix[idx[0], idx[1]] = val
    return obs_matrix


def pull_obs_from_target(
    obs_grid: ndarray, target_shape: Shape, level_shape: Shape
) -> Obs:
    """
    Read in additional observations on the target grid which align with the current grid.

    Only target cells at ``(i * width_ratio, j * height_ratio)`` are considered.
    Each non-zero value there becomes an observation ``((i, j), value)`` in the
    level grid's coordinates. Observations are returned in row-major order.

    NOTE(review): unobserved cells are 0 in ``obs_grid``, as produced by
    ``fill_obs_matrix``. An observation whose value is exactly 0 therefore
    cannot be told apart from "no observation" and is dropped.
    """
    assert target_shape.width % level_shape.width == 0
    assert target_shape.height % level_shape.height == 0
    width_ratio = target_shape.width // level_shape.width
    height_ratio = target_shape.height // level_shape.height

    # Select only the fine-grid points that are also on the coarse grid. Strided
    # slicing maps them directly to coarse coordinates. Masking by
    # multiplication, by contrast, let a NaN/inf at a non-collocated point
    # survive (NaN * 0 == NaN) and leak into the coarse grid at a wrong location.
    level_grid = obs_grid[::width_ratio, ::height_ratio]
    return [((x, y), level_grid[x, y]) for x, y in np.argwhere(level_grid != 0)]