da-message-passing / src / damp / message_passing.py
message_passing.py
Raw
from __future__ import annotations

from typing import Optional, Union

import chex
import jax.numpy as jnp
from jax import Array, jit, vmap
from tqdm import tqdm

from damp.graph import Graph, Index


@chex.dataclass(frozen=True)
class Edges:
    a: Array
    b: Array


@chex.dataclass(frozen=True)
class Marginals:
    mean: Array
    std: Array


@chex.dataclass(frozen=True)
class Config:
    graph: Graph
    c: Union[float, Graph]
    Gamma_diagonal: Array
    h: Array
    lr: float = 1.0

    def __post_init__(self) -> None:
        assert self.Gamma_diagonal.ndim == 1
        assert self.h.ndim == 1

        # This assertion is disabled for now because it was making things slow.
        # if isinstance(self.c, Stencil):
        #     chex.assert_trees_all_close(self.graph.mask, self.c.mask)
        if isinstance(self.c, Array):
            assert self.c.shape == (), "If c isn't a graph must be a scalar"


def get_initial_edges(graph: Graph) -> Edges:
    return Edges(a=graph.init_edges(value=0.0), b=graph.init_edges(value=1e-8))


def send_message(cfg: Config, edges: Edges, i: Index, j: Index) -> Edges:
    c_graph = _get_c_graph(cfg)

    # These equations are taken from the paper
    # "Message-Passing Algorithms for Quadratic Minimization"; Ruozzi 2013
    a_ji = cfg.graph.get_edge(edges.a, j, i)
    b_ji = cfg.graph.get_edge(edges.b, j, i)

    A_ij = c_graph.weighted_sum_incoming_edges(edges.a, i)
    A_ij += cfg.Gamma_diagonal[i]
    A_ij -= a_ji

    B_ij = c_graph.weighted_sum_incoming_edges(edges.b, i)
    B_ij += cfg.h[i]
    B_ij -= b_ji

    Gamma_ij_over_c_ij = cfg.graph.get_weight(i, j) / c_graph.get_weight(i, j)
    a_ij = -(Gamma_ij_over_c_ij**2) / A_ij
    b_ij = -(B_ij * Gamma_ij_over_c_ij) / A_ij

    return Edges(a=a_ij, b=b_ij)


def _get_c_graph(cfg: Config) -> Graph:
    # We may either have a single shared c, or a separate c for every edge in the
    # stencil. If the former, convert it into the latter for convenience.
    if isinstance(cfg.c, float) or (isinstance(cfg.c, Array) and cfg.c.shape == ()):
        return cfg.graph.duplicate_with_constant_weights(jnp.array(cfg.c))
    else:
        return cfg.c


@jit
def send_all_messages_parallel(c: Config, edges: Edges) -> Edges:
    def message_neighbours(j: Index) -> Edges:
        idxs = c.graph.get_neighour_indices(j)
        messages = vmap(lambda i: send_message(c, edges, i, j))(idxs)
        return Edges(a=messages.a, b=messages.b)

    new_edges = vmap(message_neighbours)(jnp.arange(0, edges.a.shape[0]))
    return Edges(
        a=(1.0 - c.lr) * edges.a + c.lr * new_edges.a,
        b=(1.0 - c.lr) * edges.b + c.lr * new_edges.b,
    )


@jit
def _extract_marginals(c: Config, edges: Edges) -> Marginals:
    return vmap(lambda i: _extract_marginal(c, edges, i))(
        jnp.arange(0, edges.a.shape[0])
    )


def _extract_marginal(cfg: Config, edges: Edges, i: Index) -> Marginals:
    c = _get_c_graph(cfg)
    precision = cfg.Gamma_diagonal[i] + c.weighted_sum_incoming_edges(edges.a, i)
    shift = cfg.h[i] + c.weighted_sum_incoming_edges(edges.b, i)

    return Marginals(mean=shift / precision, std=jnp.sqrt(1 / precision))


def iterate(
    c: Config,
    initial_edges: Edges,
    n_iterations: int,
    progress_bar: bool = True,
    early_stopping_threshold: Optional[float] = None,
) -> tuple[Edges, Marginals]:
    history = iterate_with_history(
        c,
        initial_edges,
        n_iterations,
        save_every=-1,
        progress_bar=progress_bar,
        early_stopping_threshold=early_stopping_threshold,
    )
    _, final_edges, final_nodes = history[-1]
    return final_edges, final_nodes


def iterate_with_history(
    c: Config,
    initial_edges: Edges,
    n_iterations: int,
    save_every: int = 1,
    progress_bar: bool = True,
    early_stopping_threshold: Optional[float] = None,
) -> list[tuple[int, Edges, Marginals]]:
    edge_history = [(0, initial_edges)]
    edges = initial_edges
    initial_delta = None
    iterator = tqdm(range(n_iterations)) if progress_bar else range(n_iterations)
    for i in iterator:
        new_edges = send_all_messages_parallel(c, edges)

        # We only check for early stopping every now and then to avoid slowing down the
        # process too much.
        if early_stopping_threshold is not None and i % 50 == 0:
            delta_magnitude = _get_delta_magnitude(edges, new_edges)
            if initial_delta is None:
                initial_delta = delta_magnitude
            should_early_stop = bool(
                delta_magnitude < (initial_delta * early_stopping_threshold)
            )
        else:
            should_early_stop = False
        last_iteration = should_early_stop or (i + 1 == n_iterations)

        if (save_every > 0 and (i + 1) % save_every == 0) or last_iteration:
            edge_history.append((i + 1, new_edges))

        edges = new_edges

        if should_early_stop:
            break

    result = [(i, e, _extract_marginals(c, e)) for i, e in edge_history]
    chex.block_until_chexify_assertions_complete()
    return result


@jit
def _get_delta_magnitude(previous: Edges, new: Edges) -> Array:
    # There are nans in the graph data structure which do not correspond to actual
    # edges. Thus ignore these when taking the mean.
    return jnp.maximum(
        jnp.nanmean(jnp.abs(previous.a - new.a)).mean(),
        jnp.nanmean(jnp.abs(previous.b - new.b)).mean(),
    )