# da-message-passing/tests/test_multigrid.py: tests for damp.multigrid.
import numpy as np
import pytest

from damp import multigrid
from damp.gp import Shape


def test__fill_obs_matrix() -> None:
    """Each ((row, col), value) observation is written into an otherwise zero grid."""
    obs = [
        ((0, 0), 1),
        ((1, 1), 2),
        ((2, 1), -1),
        ((2, 2), 3),
    ]
    target_shape = Shape(3, 3)
    # Cells without an observation stay 0; negative values are kept as-is.
    expected_result = np.array([[1, 0, 0], [0, 2, 0], [0, -1, 3]])

    result = multigrid.fill_obs_matrix(obs, target_shape)
    # Unlike `assert np.array_equal(...)`, this reports the mismatching
    # shape/elements when the test fails.
    np.testing.assert_array_equal(result, expected_result)


def test__fill_obs_matrix__with_empty_obs() -> None:
    """With no observations the result is an all-zero grid of ``target_shape``."""
    obs = []
    target_shape = Shape(3, 3)
    expected_result = np.zeros(target_shape)

    result = multigrid.fill_obs_matrix(obs, target_shape)
    # assert_array_equal broadcasts a scalar against an array, so pin the
    # shape explicitly; np.shape also accepts non-ndarray results.
    assert np.shape(result) == expected_result.shape
    # Reports mismatching elements on failure instead of a bare "assert False".
    np.testing.assert_array_equal(result, expected_result)


def test__fill_obs_matrix__with_out_of_bounds_index():
    """An observation outside the target grid makes fill_obs_matrix raise IndexError."""
    # Observations that all fit inside a 3x3 grid ...
    in_bounds = [
        ((0, 0), 1),
        ((1, 1), 2),
        ((2, 1), -1),
        ((2, 2), 3),
    ]
    # ... followed by one whose row index (3) is past the last row (2).
    out_of_bounds = ((3, 1), 4)
    obs = in_bounds + [out_of_bounds]
    grid_shape = Shape(3, 3)

    with pytest.raises(IndexError):
        multigrid.fill_obs_matrix(obs, grid_shape)


def test__pull_obs_from_target__obs_have_correct_value() -> None:
    """Pulling an 8x8 target onto a 2x2 coarse grid yields one observation per
    coarse cell, carrying the target values at rows/columns 0 and 4."""
    coarse_shape = Shape(2, 2)
    target_shape = Shape(8, 8)
    # Target grid holding 1..64 in row-major order, so each value identifies its cell.
    obs_grid = np.arange(1, target_shape.flatten() + 1).reshape(target_shape)

    obs = multigrid.pull_obs_from_target(obs_grid, target_shape, coarse_shape)

    # One observation per coarse cell ...
    assert len(obs) == coarse_shape.flatten()
    # ... whose values come from target cells (0, 0), (0, 4), (4, 0), (4, 4).
    expected_values = (1.0, 5.0, 33.0, 37.0)
    for index, expected in enumerate(expected_values):
        assert obs[index][1] == expected


def test__pull_obs_from_target__not_divisible_raises():
    """A target shape that is not a multiple of the coarse shape is rejected."""
    # 3 is not a multiple of 2 in either dimension.
    coarse_shape = Shape(2, 2)
    target_shape = Shape(3, 3)
    obs_grid = np.arange(1, target_shape.flatten() + 1).reshape(target_shape)

    # NOTE(review): expecting AssertionError suggests multigrid validates with
    # a bare `assert`, which `python -O` strips; consider raising ValueError
    # there and updating this expectation.
    with pytest.raises(AssertionError):
        multigrid.pull_obs_from_target(obs_grid, target_shape, coarse_shape)