# da-message-passing / src / damp / message_passing.py
from __future__ import annotations

from typing import Optional, Union

import chex
import jax.numpy as jnp
from jax import Array, jit, vmap
from tqdm import tqdm

from damp.graph import Graph, Index

@chex.dataclass
class Edges:
    """Message parameters carried by the edges of the graph.

    The class is a ``chex`` dataclass because instances are built with keyword
    arguments (``Edges(a=..., b=...)``) and are returned from ``vmap``-ed
    functions, which requires them to be JAX pytrees.
    """

    # Summed (with the c weights) into the per-node precision in
    # `_extract_marginal`.
    a: Array
    # Summed (with the c weights) into the per-node shift in
    # `_extract_marginal`.
    b: Array

@chex.dataclass
class Marginals:
    """Per-node marginal estimates derived from the edge messages.

    The class is a ``chex`` dataclass because instances are built with keyword
    arguments and are returned from a ``vmap``-ed function, which requires them
    to be JAX pytrees.
    """

    # shift / precision, as computed in `_extract_marginal`.
    mean: Array
    # sqrt(1 / precision), as computed in `_extract_marginal`.
    std: Array

@chex.dataclass
class Config:
    """Problem definition and hyper-parameters for message passing.

    Declared as a ``chex`` dataclass so that the fields below become
    constructor arguments and ``__post_init__`` actually runs after
    construction.
    """

    # Graph whose edge weights are read through `get_weight` when sending
    # messages.
    graph: Graph
    # Either one scalar shared by every edge, or a graph holding one weight per
    # edge (see `_get_c_graph`).
    c: Union[float, Graph]
    # One-dimensional array, indexed by node.
    Gamma_diagonal: Array
    # One-dimensional array, indexed by node.
    h: Array
    # Damping factor: updated edges are (1 - lr) * old + lr * new.
    lr: float = 1.0

    def __post_init__(self) -> None:
        """Check the shapes of the array-valued fields."""
        assert self.Gamma_diagonal.ndim == 1
        assert self.h.ndim == 1

        # This assertion is disabled for now because it was making things slow.
        # if isinstance(self.c, Stencil):
        #     chex.assert_trees_all_close(self.graph.mask, self.c.mask)
        if isinstance(self.c, Array):
            assert self.c.shape == (), "If c isn't a graph must be a scalar"

def get_initial_edges(graph: Graph) -> Edges:
    """Build the starting messages for ``graph``.

    Args:
        graph: Graph providing ``init_edges`` to allocate per-edge storage.

    Returns:
        Edges whose ``a`` values are all 0.0 and whose ``b`` values are all 1e-8.
    """
    initial_a = graph.init_edges(value=0.0)
    initial_b = graph.init_edges(value=1e-8)
    return Edges(a=initial_a, b=initial_b)

def send_message(cfg: Config, edges: Edges, i: Index, j: Index) -> Edges:
    """Compute the updated ``(a, b)`` message for the node pair ``(i, j)``.

    The update follows "Message-Passing Algorithms for Quadratic Minimization"
    (Ruozzi 2013); the result corresponds to ``a_ij`` / ``b_ij`` in that
    notation.

    Args:
        cfg: Problem definition (graph weights, c, Gamma diagonal and h).
        edges: Current messages on every edge.
        i: Node whose incoming messages are aggregated.
        j: The other endpoint of the edge.

    Returns:
        Edges holding only the new ``a`` and ``b`` values for this pair.
    """
    graph = cfg.graph
    c_graph = _get_c_graph(cfg)

    # Current values stored for the (j, i) edge; they are removed from the
    # aggregated sums below.
    reverse_a = graph.get_edge(edges.a, j, i)
    reverse_b = graph.get_edge(edges.b, j, i)

    # c-weighted sum of messages arriving at i, plus the node's own term.
    incoming_a = c_graph.weighted_sum_incoming_edges(edges.a, i)
    denominator = incoming_a + cfg.Gamma_diagonal[i] - reverse_a

    incoming_b = c_graph.weighted_sum_incoming_edges(edges.b, i)
    numerator = incoming_b + cfg.h[i] - reverse_b

    weight_ratio = graph.get_weight(i, j) / c_graph.get_weight(i, j)
    new_a = -(weight_ratio**2) / denominator
    new_b = -(numerator * weight_ratio) / denominator

    return Edges(a=new_a, b=new_b)

def _get_c_graph(cfg: Config) -> Graph:
    """Return ``cfg.c`` in per-edge (graph) form.

    Args:
        cfg: Configuration whose ``c`` is either a scalar (Python float or
            zero-dimensional array) or already a graph of per-edge weights.

    Returns:
        ``cfg.c`` itself when it is already a graph; otherwise a copy of
        ``cfg.graph`` whose weights are all set to the scalar ``c``.
    """
    # We may either have a single shared c, or a separate c for every edge in the
    # graph. If the former, convert it into the latter for convenience.
    if isinstance(cfg.c, float) or (isinstance(cfg.c, Array) and cfg.c.shape == ()):
        return cfg.graph.duplicate_with_constant_weights(jnp.array(cfg.c))
    # Previously this return was nested under the `if` above and therefore
    # unreachable, so a graph-valued c produced None.
    return cfg.c

def send_all_messages_parallel(c: Config, edges: Edges) -> Edges:
    """Recompute every message at once and blend it with the current one.

    Args:
        c: Problem definition; ``c.lr`` is the damping factor.
        edges: Current messages on every edge.

    Returns:
        Edges equal to ``(1 - lr) * edges + lr * recomputed`` for both ``a``
        and ``b``.
    """

    def message_neighbours(j: Index) -> Edges:
        # One message per neighbour index of j, each computed from the
        # *current* edges so that all updates are synchronous.
        idxs = c.graph.get_neighour_indices(j)
        messages = vmap(lambda i: send_message(c, edges, i, j))(idxs)
        return Edges(a=messages.a, b=messages.b)

    # The leading axis of the edge arrays is used as the node index.
    new_edges = vmap(message_neighbours)(jnp.arange(0, edges.a.shape[0]))
    return Edges(
        a=(1.0 - c.lr) * edges.a + c.lr * new_edges.a,
        b=(1.0 - c.lr) * edges.b + c.lr * new_edges.b,
    )

def _extract_marginals(c: Config, edges: Edges) -> Marginals:
    """Compute the marginal of every node from the current messages.

    Args:
        c: Problem definition.
        edges: Current messages on every edge.

    Returns:
        Marginals whose ``mean`` and ``std`` hold one entry per node, where the
        node count is taken from the leading axis of ``edges.a``.
    """
    return vmap(lambda i: _extract_marginal(c, edges, i))(
        jnp.arange(0, edges.a.shape[0])
    )

def _extract_marginal(cfg: Config, edges: Edges, i: Index) -> Marginals:
    """Compute the marginal mean and standard deviation of node ``i``.

    Args:
        cfg: Problem definition (c weights, Gamma diagonal and h).
        edges: Current messages on every edge.
        i: Index of the node.

    Returns:
        Marginals with ``mean = shift / precision`` and
        ``std = sqrt(1 / precision)``.
    """
    c_graph = _get_c_graph(cfg)

    # Node term plus the c-weighted sum of the messages arriving at i.
    incoming_a = c_graph.weighted_sum_incoming_edges(edges.a, i)
    precision = cfg.Gamma_diagonal[i] + incoming_a

    incoming_b = c_graph.weighted_sum_incoming_edges(edges.b, i)
    shift = cfg.h[i] + incoming_b

    variance = 1 / precision
    return Marginals(mean=shift / precision, std=jnp.sqrt(variance))

def iterate(
    c: Config,
    initial_edges: Edges,
    n_iterations: int,
    progress_bar: bool = True,
    early_stopping_threshold: Optional[float] = None,
) -> tuple[Edges, Marginals]:
    """Run message passing and return only the final state.

    Args:
        c: Problem definition.
        initial_edges: Messages to start from.
        n_iterations: Maximum number of synchronous update rounds.
        progress_bar: Whether to show a tqdm progress bar.
        early_stopping_threshold: Forwarded to `iterate_with_history`; ``None``
            disables early stopping.

    Returns:
        The edges after the last executed iteration and the marginals
        extracted from them.
    """
    # save_every=0 turns off periodic snapshots; the last iteration is always
    # recorded, so history[-1] is the final state either way.
    history = iterate_with_history(
        c,
        initial_edges,
        n_iterations,
        save_every=0,
        progress_bar=progress_bar,
        early_stopping_threshold=early_stopping_threshold,
    )
    _, final_edges, final_nodes = history[-1]
    return final_edges, final_nodes

def iterate_with_history(
    c: Config,
    initial_edges: Edges,
    n_iterations: int,
    save_every: int = 1,
    progress_bar: bool = True,
    early_stopping_threshold: Optional[float] = None,
) -> list[tuple[int, Edges, Marginals]]:
    """Run message passing, recording snapshots along the way.

    Args:
        c: Problem definition.
        initial_edges: Messages to start from.
        n_iterations: Maximum number of synchronous update rounds.
        save_every: Record a snapshot every this many iterations. Values
            ``<= 0`` disable periodic snapshots.
        progress_bar: Whether to wrap the loop in a tqdm progress bar.
        early_stopping_threshold: If given, stop once the change between
            successive edges drops below this fraction of the first measured
            change. Checked only every 50 iterations.

    Returns:
        A list of ``(iteration, edges, marginals)`` tuples. The first entry is
        always iteration 0 with ``initial_edges``; the last executed iteration
        is always included.
    """
    edge_history = [(0, initial_edges)]
    edges = initial_edges
    initial_delta = None
    iterator = tqdm(range(n_iterations)) if progress_bar else range(n_iterations)
    for i in iterator:
        new_edges = send_all_messages_parallel(c, edges)

        # We only check for early stopping every now and then to avoid slowing down the
        # process too much.
        should_early_stop = False
        if early_stopping_threshold is not None and i % 50 == 0:
            delta_magnitude = _get_delta_magnitude(edges, new_edges)
            if initial_delta is None:
                # The first measured change is the reference for the relative
                # threshold.
                initial_delta = delta_magnitude
            should_early_stop = bool(
                delta_magnitude < (initial_delta * early_stopping_threshold)
            )
        last_iteration = should_early_stop or (i + 1 == n_iterations)

        # The `save_every > 0` guard also protects the modulo from a zero divisor.
        if (save_every > 0 and (i + 1) % save_every == 0) or last_iteration:
            edge_history.append((i + 1, new_edges))

        edges = new_edges

        if should_early_stop:
            break

    result = [(step, e, _extract_marginals(c, e)) for step, e in edge_history]
    return result

def _get_delta_magnitude(previous: Edges, new: Edges) -> Array:
    """Measure how much the messages changed between two iterations.

    Args:
        previous: Edges before the update.
        new: Edges after the update.

    Returns:
        The larger of the NaN-ignoring mean absolute differences of ``a`` and
        of ``b``, as a scalar array.
    """
    # There are nans in the graph data structure which do not correspond to actual
    # edges. Thus ignore these when taking the mean.
    return jnp.maximum(
        jnp.nanmean(jnp.abs(previous.a - new.a)),
        jnp.nanmean(jnp.abs(previous.b - new.b)),
    )