# Source: da-message-passing/tests/test_message_passing.py
import chex
import jax.numpy as jnp
from jax import jit
from numpy.random import default_rng

from damp import gp, message_passing, so_impl
from damp.gp import Shape
from damp.graph import Graph, graph_and_diagonal_from_precision_matrix
from damp.message_passing import Config, Edges


def test__send_message__equal_to_reference_impl_over_entire_graph() -> None:
    """Every JAX message matches the reference implementation's message.

    Messages are sent one at a time along every directed edge of the reference
    graph. After each send the JAX edge arrays are updated with the new message,
    and the reference graph sends the same message, so both implementations
    advance through identical states before the next comparison.
    """
    config, edges, reference = set_up_graph()

    jitted_send = jit(message_passing.send_message)
    jitted_set_edge = jit(config.graph.set_edge)

    for src in reference.nodes:
        for dst in reference.neighbors(src):
            msg = jitted_send(config, edges, src, dst)
            new_a = jitted_set_edge(edges.a, src, dst, msg.a)
            new_b = jitted_set_edge(edges.b, src, dst, msg.b)
            edges = Edges(a=new_a, b=new_b)

            reference.send_message(src, dst)
            want = reference.edges[(src, dst)]["message"]

            assert jnp.allclose(msg.a, want.precision, atol=1e-9)
            assert jnp.allclose(msg.b, want.shift, atol=1e-9)


def test__send_message__i_outside_grid__returns_nan() -> None:
    """A sender index that lies outside the interior grid yields a NaN message."""
    config, initial_edges, _ = set_up_graph()

    # (sender (x, y), receiver (x, y)) pairs where the sender is off the grid.
    off_grid_senders = (
        ((-1, 0), (0, 0)),
        ((0, -1), (0, 0)),
        ((10, 10), (9, 9)),
    )
    for sender, receiver in off_grid_senders:
        msg = message_passing.send_message(
            config, initial_edges, i=to_idx(*sender), j=to_idx(*receiver)
        )
        assert jnp.isnan(msg.a)
        assert jnp.isnan(msg.b)


def test__send_message__j_outside_grid__returns_nan() -> None:
    """A receiver index that lies outside the interior grid yields a NaN message.

    ``to_idx`` does not bounds-check its coordinates, so an out-of-range *x*
    wraps onto a valid node. For example, ``to_idx(12, 0) == 12`` is node (4, 1),
    which would only test the "not neighbours" path. Out-of-grid receivers are
    therefore built from a negative index or from an out-of-range *y* row.
    """
    config, initial_edges, _ = set_up_graph()
    send_message = message_passing.send_message

    # Negative index: before the first node.
    message = send_message(config, initial_edges, i=to_idx(0, 0), j=to_idx(-1, 0))
    assert jnp.isnan(message.a)
    assert jnp.isnan(message.b)
    # Row 12 lies past the last row, so the flat index (96) is beyond the node
    # array (NOTE(review): assumes the 8x8 interior grid that to_idx encodes).
    message = send_message(config, initial_edges, i=to_idx(0, 0), j=to_idx(0, 12))
    assert jnp.isnan(message.a)
    assert jnp.isnan(message.b)


def test__send_message__i_j_not_neighbours__returns_nan() -> None:
    """Sending between two in-grid nodes that share no edge yields NaN."""
    config, initial_edges, _ = set_up_graph()

    # (sender (x, y), receiver (x, y)) pairs that are not grid-adjacent.
    non_adjacent = [
        ((0, 0), (0, 3)),
        ((0, 0), (2, 2)),
        ((4, 4), (6, 6)),
    ]
    for sender, receiver in non_adjacent:
        msg = message_passing.send_message(
            config, initial_edges, i=to_idx(*sender), j=to_idx(*receiver)
        )
        assert jnp.isnan(msg.a)
        assert jnp.isnan(msg.b)


def test__send_message__multiple_cs__messages_depends_on_c() -> None:
    """Per-edge ``c`` weights change the outgoing messages of the affected node.

    The ``c`` graph starts as a copy of the model graph with constant weight 1.0.
    Edge slot 0 of node ``to_idx(5, 5)`` (== 45) is then set to 2.0. Messages
    from that node towards (5, 6) and (6, 5) are expected to match. The message
    towards (5, 4) is expected to differ.
    """
    config, initial_edges, _ = set_up_graph()
    send_message = message_passing.send_message

    sender = to_idx(5, 5)
    c_graph = config.graph.duplicate_with_constant_weights(weight=jnp.array(1.0))
    c_graph.weights = c_graph.weights.at[sender, 0].set(2.0)
    config = config.replace(c=c_graph)

    message_1 = send_message(config, initial_edges, i=sender, j=to_idx(5, 6))
    message_2 = send_message(config, initial_edges, i=sender, j=to_idx(6, 5))
    # (5, 4) is a direct neighbour of (5, 5). The previous target (5, 3) is two
    # rows away, so send_message returned NaN and the inequality check below
    # passed vacuously.
    # NOTE(review): assumes edge slot 0 of node 45 is the edge to (5, 4), i.e.
    # the lowest neighbour index; confirm against the graph construction.
    message_3 = send_message(config, initial_edges, i=sender, j=to_idx(5, 4))

    # Guard against NaN messages, which would make the comparisons meaningless.
    for message in (message_1, message_2, message_3):
        assert not jnp.isnan(message.a)
        assert not jnp.isnan(message.b)

    assert jnp.allclose(message_1.b, message_2.b)
    assert jnp.allclose(message_1.a, message_2.a)
    assert not jnp.allclose(message_3.a, message_2.a)


def test__send_all_messages_parallel__result_has_correct_shape() -> None:
    """One parallel round of message passing preserves the edge array shapes."""
    config, edges_before, _ = set_up_graph()

    edges_after = message_passing.send_all_messages_parallel(config, edges_before)

    shapes_before = (edges_before.a.shape, edges_before.b.shape)
    shapes_after = (edges_after.a.shape, edges_after.b.shape)
    assert shapes_after == shapes_before


def test__iterate__parallel__result_has_correct_shape() -> None:
    """``iterate`` returns 1-D marginals sized like the edge arrays' first axis.

    Checked both with zero iterations and with a single iteration.
    """
    config, initial_edges, _ = set_up_graph()
    expected_shape = (initial_edges.a.shape[0],)

    for n_iterations in (0, 1):
        _, marginals = message_passing.iterate(
            config, initial_edges, n_iterations=n_iterations
        )
        assert marginals.mean.shape == expected_shape
        assert marginals.std.shape == expected_shape


def test__iterate__multiple_cs__result_has_correct_shape() -> None:
    """``iterate`` accepts a per-edge ``c`` graph and keeps the marginal shapes.

    Checked both with zero iterations and with a single iteration.
    """
    config, initial_edges, _ = set_up_graph()
    expected_shape = (initial_edges.a.shape[0],)

    # Weight 1.0 on every connected slot (unused slots are -1 in connectivity),
    # then 2.0 on two individual edge slots.
    connected = config.graph.connectivity != -1
    c_weights = (
        config.graph.weights.at[connected]
        .set(1.0)
        .at[1, 2]
        .set(2.0)
        .at[2, 1]
        .set(2.0)
    )
    config = config.replace(c=Graph(config.graph.connectivity, weights=c_weights))

    for n_iterations in (0, 1):
        _, marginals = message_passing.iterate(
            config, initial_edges, n_iterations=n_iterations
        )
        assert marginals.mean.shape == expected_shape
        assert marginals.std.shape == expected_shape


def set_up_graph(
    grid_size: int = 10,
) -> tuple[message_passing.Config, Edges, so_impl.FactorGraphFromArray]:
    """Build a small GP posterior and both message-passing views of it.

    A prior on a ``grid_size`` x ``grid_size`` grid is sampled with a fixed seed.
    Five observations (noise 1e-3) of that sample are chosen, and the posterior
    precision/shift is turned into:

    * a JAX ``Config`` together with its initial ``Edges``, and
    * the reference ``so_impl.FactorGraphFromArray``.

    Both use the same ``c`` value, so tests can compare them message by message.
    """
    rng = default_rng(seed=1124)
    c_value = -1.0

    prior = gp.get_prior(Shape(grid_size, grid_size))
    truth = gp.sample_prior(rng, prior)
    observations = gp.choose_observations(
        rng, n_obs=5, ground_truth=truth, obs_noise=1e-3
    )
    posterior = gp.get_posterior(prior, observations)

    graph, gamma_diag = graph_and_diagonal_from_precision_matrix(posterior.precision)

    config = Config(
        graph=graph,
        c=c_value,
        Gamma_diagonal=jnp.array(gamma_diag),
        h=jnp.array(posterior.shift),
    )
    edges = message_passing.get_initial_edges(graph)
    reference = so_impl.FactorGraphFromArray(
        posterior.precision, posterior.shift, c=c_value
    )
    return config, edges, reference


def to_idx(x: int, y: int, interior_size: int = 8) -> int:
    """Map grid coordinates ``(x, y)`` to a flat, row-major node index.

    ``interior_size`` is the grid width used for the flattening. The default of
    8 is what the tests in this module assume for ``set_up_graph``'s default
    ``grid_size=10`` (presumably the 10x10 grid minus its boundary ring; TODO
    confirm against ``gp.get_prior``).

    No bounds checking is performed. Negative coordinates give negative indices,
    and an out-of-range ``x`` wraps onto the next row (e.g. ``to_idx(12, 0)`` is
    12, the in-grid node (4, 1)). Callers must not rely on a large ``x`` to
    produce an out-of-grid index.
    """
    return y * interior_size + x


def teardown_function() -> None:
    """pytest xunit-style hook, run after every test function in this module.

    Waits for any outstanding ``chex.chexify`` checks to finish, so that a
    failing asynchronous check is raised during this test's teardown.
    """
    wait_for_chexify_checks = chex.block_until_chexify_assertions_complete
    wait_for_chexify_checks()