# da-message-passing/src/damp/multigrid.py
from collections import defaultdict
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import Array, jit
from numpy import ndarray

from damp import gp, message_passing
from damp.gp import Obs, Prior, Shape
from damp.graph import graph_and_diagonal_from_precision_matrix
from damp.message_passing import Config, Edges, Marginals

# Dense grid of observation values, as returned by `fill_obs_matrix`.
ObsMatrix = ndarray

# Iteration budget per grid shape. Any shape without an explicit entry falls back
# to 10,000; `run` and `run_on_sphere` merge caller overrides on top of this mapping.
DEFAULT_ITERATIONS: dict[Shape, int] = defaultdict(lambda: 10_000)


def run(
    prior: Prior,
    obs: Obs,
    obs_noise: float,
    min_grid_size: int = 32,
    iterations: dict[Shape, int] = DEFAULT_ITERATIONS,
    **config_kwargs,
) -> list[tuple[Shape, int, Marginals]]:
    """Run message passing on a coarse-to-fine sequence of grids.

    The levels come from ``build_levels(min_grid_size, prior.shape)``. Each level
    conditions on the observations ``pull_obs_from_target`` returns for it. The
    last level uses ``prior`` itself; earlier levels use
    ``gp.get_prior(level_shape)``. The first level starts from
    ``message_passing.get_initial_edges``; every later level starts from the
    previous level's edges after ``_expand_edges``.

    Args:
        prior: Prior on the target grid; ``prior.shape`` is the last level.
        obs: Observations on the target grid as ``(index, value)`` pairs.
        obs_noise: Forwarded to ``gp.get_posterior`` at every level.
        min_grid_size: Minimum size handed to ``build_levels``.
        iterations: Per-shape ``n_iterations`` overrides, merged on top of
            ``DEFAULT_ITERATIONS``.
        **config_kwargs: Forwarded unchanged to ``Config``.

    Returns:
        One ``(level shape, n_iterations, marginals)`` tuple per level, in level
        order. ``n_iterations`` is the value passed to
        ``message_passing.iterate``; the marginals' mean and std are reshaped to
        the level prior's interior shape.
    """
    target_shape = prior.shape
    levels = build_levels(min_grid_size, target_shape)
    # Merging builds a new mapping, so the lookups below never write into the
    # shared module-level defaults.
    iterations = DEFAULT_ITERATIONS | iterations

    obs_grid = fill_obs_matrix(obs, target_shape)

    finest_index = len(levels) - 1
    results = []

    for level_i, level_shape in enumerate(levels):
        print(f"Running Message Passing for level {level_i}")

        # TODO: Ensure the level prior has the same parameters as the target prior.
        # At the moment this doesn't matter because the only parameter is the grid
        # shape.
        level_prior = prior if level_i == finest_index else gp.get_prior(level_shape)
        level_obs = pull_obs_from_target(obs_grid, target_shape, level_shape)
        posterior = gp.get_posterior(level_prior, level_obs, obs_noise)

        graph, Gamma_diagonal = graph_and_diagonal_from_precision_matrix(
            posterior.precision
        )
        config = Config(
            graph=graph,
            Gamma_diagonal=jnp.array(Gamma_diagonal),
            h=jnp.array(posterior.shift),
            **config_kwargs,
        )

        if level_i > 0:
            # Warm start: reuse the messages of the previous (coarser) level.
            coarser = levels[level_i - 1]
            initial_edges = _expand_edges(
                edges,
                Shape(coarser.width - 2, coarser.height - 2),
                level_prior.interior_shape,
            )
        else:
            initial_edges = message_passing.get_initial_edges(graph)

        n_iterations = iterations[level_shape]
        edges, flat_marginals = message_passing.iterate(
            config,
            initial_edges,
            n_iterations=n_iterations,
            early_stopping_threshold=0.0001,
        )
        interior_shape = level_prior.interior_shape
        level_marginals = Marginals(
            mean=flat_marginals.mean.reshape(interior_shape),
            std=flat_marginals.std.reshape(interior_shape),
        )
        results.append((level_shape, n_iterations, level_marginals))

    return results

def run_on_sphere(
    prior: Prior,
    obs: Obs,
    obs_noise: float,
    lat: ndarray,
    lon: ndarray,
    ratios: "list | None" = None,
    levels: "list | None" = None,
    iterations: dict[Shape, int] = DEFAULT_ITERATIONS,
    **config_kwargs,
) -> list[tuple[Shape, int, Marginals]]:
    """Run message passing on a caller-supplied sequence of spherical grids.

    Unlike ``run``, the levels are given explicitly. The last level uses
    ``prior`` itself. Every earlier level ``i`` uses
    ``gp.get_prior_sphere(levels[i], lon[::ratios[i]], lat[::ratios[i]])``, i.e.
    the coordinate arrays subsampled with stride ``ratios[i]``. The first level
    starts from ``message_passing.get_initial_edges``; every later level starts
    from the previous level's edges after ``_expand_edges``.

    Args:
        prior: Prior on the target grid; ``prior.shape`` is the target shape.
        obs: Observations on the target grid as ``(index, value)`` pairs.
        obs_noise: Forwarded to ``gp.get_posterior`` at every level.
        lat: Coordinate array subsampled for the non-final levels.
        lon: Coordinate array subsampled for the non-final levels.
        ratios: Subsampling stride per non-final level, indexed like ``levels``.
            ``None`` (the default) means no strides.
        levels: Grid shapes to solve, in the order they are run. ``None`` (the
            default) means no levels, which yields an empty result.
        iterations: Per-shape ``n_iterations`` overrides, merged on top of
            ``DEFAULT_ITERATIONS``.
        **config_kwargs: Forwarded unchanged to ``Config``.

    Returns:
        One ``(level shape, n_iterations, marginals)`` tuple per level, in level
        order. ``n_iterations`` is the value passed to
        ``message_passing.iterate``; the marginals' mean and std are reshaped to
        the level prior's interior shape.

    Raises:
        ValueError: If ``ratios`` has fewer entries than there are non-final
            levels.
    """
    # None sentinels instead of list literals: default lists are created once at
    # definition time and would be shared between calls.
    ratios = [] if ratios is None else ratios
    levels = [] if levels is None else levels

    # Every level except the last reads its own stride, so check up front rather
    # than failing with an IndexError after the earlier levels have been solved.
    n_coarse_levels = max(len(levels) - 1, 0)
    if len(ratios) < n_coarse_levels:
        raise ValueError(
            f"ratios must provide a stride for each of the {n_coarse_levels} "
            f"non-final levels, got {len(ratios)}"
        )

    target_shape = prior.shape

    # Merging builds a new mapping, so the lookups below never write into the
    # shared module-level defaults.
    iterations = DEFAULT_ITERATIONS | iterations

    obs_grid = fill_obs_matrix(obs, target_shape)

    finest_index = len(levels) - 1
    results = []

    for level_i, level_shape in enumerate(levels):
        print(f"Running Message Passing for level {level_i}")

        if level_i == finest_index:
            level_prior = prior
        else:
            # TODO: Ensure the level prior has the same parameters as the target prior.
            # At the moment this doesn't matter because the only parameter is the grid
            # shape.
            stride = ratios[level_i]
            level_prior = gp.get_prior_sphere(
                level_shape, lon[::stride], lat[::stride]
            )
        level_obs = pull_obs_from_target(obs_grid, target_shape, level_shape)
        posterior = gp.get_posterior(level_prior, level_obs, obs_noise)

        graph, Gamma_diagonal = graph_and_diagonal_from_precision_matrix(
            posterior.precision
        )
        config = Config(
            graph=graph,
            Gamma_diagonal=jnp.array(Gamma_diagonal),
            h=jnp.array(posterior.shift),
            **config_kwargs,
        )

        if level_i == 0:
            initial_edges = message_passing.get_initial_edges(graph)
        else:
            # Warm start: reuse the messages of the previous level.
            previous_level = levels[level_i - 1]
            initial_edges = _expand_edges(
                edges,
                Shape(previous_level.width - 2, previous_level.height - 2),
                level_prior.interior_shape,
            )

        n_iterations = iterations[level_shape]
        edges, flat_marginals = message_passing.iterate(
            config,
            initial_edges,
            n_iterations=n_iterations,
            early_stopping_threshold=0.0001,
        )
        interior_shape = level_prior.interior_shape
        level_marginals = Marginals(
            mean=flat_marginals.mean.reshape(interior_shape),
            std=flat_marginals.std.reshape(interior_shape),
        )
        results.append((level_shape, n_iterations, level_marginals))

    return results

def _expand_edges(edges: Edges, previous_shape: Shape, next_shape: Shape) -> Edges:
    """Build the initial edges of the next multigrid level from the previous ones.

    ``edges.a`` and ``edges.b`` are reshaped to
    ``(previous_shape.width, previous_shape.height, n_columns)`` and handed to
    ``_expand_edges_jittable`` together with the integer (floor) size ratios
    between ``next_shape`` and ``previous_shape``.
    """
    coarse_width = previous_shape.width
    coarse_height = previous_shape.height

    # Restore the 2-D grid layout; the trailing axis of each array is kept.
    a_on_grid = edges.a.reshape(coarse_width, coarse_height, edges.a.shape[1])
    b_on_grid = edges.b.reshape(coarse_width, coarse_height, edges.b.shape[1])

    return _expand_edges_jittable(
        a_grid=a_on_grid,
        b_grid=b_on_grid,
        width_ratio=next_shape.width // coarse_width,
        height_ratio=next_shape.height // coarse_height,
    )


@partial(jit, static_argnames=("width_ratio", "height_ratio"))
def _expand_edges_jittable(
    a_grid: Array, b_grid: Array, width_ratio: int, height_ratio: int
) -> Edges:
    """Upsample gridded edge arrays to a finer grid (jitted).

    Both arrays go through the same steps: multiply by
    ``width_ratio * height_ratio``, repeat every cell ``width_ratio`` times along
    axis 0 and ``height_ratio`` times along axis 1, zero-pad one cell on each side
    of the first two axes, pass through ``jnp.nan_to_num``, and flatten the first
    two axes. The ratios are static arguments of the jitted function.
    """
    scale = width_ratio * height_ratio

    def _upsample(coarse: Array) -> Array:
        fine = scale * coarse
        fine = jnp.repeat(fine, width_ratio, axis=0)
        fine = jnp.repeat(fine, height_ratio, axis=1)
        # One ring of zeros around the grid; the trailing axis is left alone.
        fine = jnp.pad(fine, ((1, 1), (1, 1), (0, 0)))
        fine = jnp.nan_to_num(fine)
        return fine.reshape(-1, coarse.shape[2])

    return Edges(a=_upsample(a_grid), b=_upsample(b_grid))


def build_levels(min_size: int, target_shape: Shape) -> list[Shape]:
    """Return the grid shapes from the coarsest level up to ``target_shape``.

    The coarsest shape scales the larger dimension of ``target_shape`` down to
    ``min_size`` and the other dimension proportionally (rounded). Each following
    level doubles both dimensions, and the last level equals ``target_shape``.

    Args:
        min_size: Size of the larger dimension on the coarsest level. Must be
            even and at least 2.
        target_shape: Shape of the finest level. Both dimensions must be even and
            divisible by ``min_size``.

    Returns:
        The level shapes, coarsest first, ending with ``target_shape``.

    Raises:
        AssertionError: If one of the divisibility requirements above is violated.
        ValueError: If repeated doubling of the coarsest shape cannot produce
            ``target_shape`` exactly.
    """
    assert target_shape.width % 2 == 0 and target_shape.height % 2 == 0
    assert min_size >= 2 and min_size % 2 == 0
    assert target_shape.width % min_size == 0 and target_shape.height % min_size == 0

    # We reduce the largest dimension to the minimum size, thus the other dimension
    # might end up smaller.
    if target_shape.width >= target_shape.height:
        min_shape = Shape(
            width=min_size,
            height=round(target_shape.height / target_shape.width * min_size),
        )
    else:
        min_shape = Shape(
            width=round(target_shape.width / target_shape.height * min_size),
            height=min_size,
        )

    levels = [min_shape]
    while levels[-1] != target_shape:
        finer = Shape(levels[-1].width * 2, levels[-1].height * 2)
        # Doubling only ever grows the shape, so once a dimension overshoots the
        # target it can never be matched. Without this check the loop would not
        # terminate (e.g. target 96x96 with min_size 32: 32 -> 64 -> 128 -> ...).
        if finer.width > target_shape.width or finer.height > target_shape.height:
            raise ValueError(
                f"Cannot reach target shape {target_shape} by repeatedly doubling "
                f"the coarsest shape {min_shape}"
            )
        levels.append(finer)
    return levels


def fill_obs_matrix(obs: Obs, target_shape: Shape) -> ObsMatrix:
    """
    Scatter sparse observations onto a dense grid of shape ``target_shape``.

    Each observation is an ``(index, value)`` pair; ``index[0]`` and ``index[1]``
    select the cell that receives ``value``. Cells without an observation stay 0,
    and if an index occurs more than once the last value wins.
    """
    grid = np.zeros(target_shape)
    for location, value in obs:
        first_axis, second_axis = location[0], location[1]
        grid[first_axis, second_axis] = value
    return grid


def pull_obs_from_target(
    obs_grid: ndarray, target_shape: Shape, level_shape: Shape
) -> Obs:
    """
    Collect the target-grid observations that fall on the grid of a coarser level.

    A target cell ``(x, y)`` lies on the level grid when ``x`` is a multiple of the
    width ratio and ``y`` is a multiple of the height ratio. Every such cell whose
    value in ``obs_grid`` is not 0 is returned, with its index converted to level
    coordinates ``(x // width_ratio, y // height_ratio)``. A value of exactly 0
    means "no observation" and is skipped.

    Args:
        obs_grid: Dense observation grid with the shape ``target_shape``.
        target_shape: Shape of the finest grid.
        level_shape: Shape of the level grid; must divide ``target_shape`` in both
            dimensions.

    Returns:
        ``((x, y), value)`` pairs in level coordinates, in row-major order.

    Raises:
        AssertionError: If ``level_shape`` does not divide ``target_shape``.
        ValueError: If ``obs_grid`` does not have the shape ``target_shape``.
    """
    assert target_shape.width % level_shape.width == 0
    assert target_shape.height % level_shape.height == 0
    width_ratio = target_shape.width // level_shape.width
    height_ratio = target_shape.height // level_shape.height

    if np.shape(obs_grid) != tuple(target_shape):
        raise ValueError(
            f"obs_grid has shape {np.shape(obs_grid)}, expected {tuple(target_shape)}"
        )

    # Take the cells shared with the level grid by strided slicing. Multiplying
    # the full grid by a 0/1 mask instead lets non-finite values at cells that are
    # NOT on the level grid slip through (0 * nan is nan, which is != 0), and it
    # needs two full-size temporaries.
    collocated = obs_grid[::width_ratio, ::height_ratio]
    # Keep the float64-promoted dtype that the multiplicative mask used to produce.
    collocated = collocated.astype(
        np.result_type(np.float64, collocated.dtype), copy=False
    )

    # Indices into the strided view already are level-grid coordinates.
    return [
        ((x, y), collocated[x, y]) for x, y in np.argwhere(collocated != 0)
    ]