# da-message-passing/tests/test_jax_utils.py
import chex
import jax.numpy as jnp
import jax.random
from jax import Array, vmap
from jax.random import PRNGKey

from damp.jax_utils import batch_vmap, tree_concatenate


@chex.dataclass
class ExamplePytree:
    """Small two-leaf container used as a fixture in the tests below.

    Decorated with ``chex.dataclass`` so instances can be passed to the
    pytree utilities under test (e.g. ``tree_concatenate``).
    """

    # First array leaf.
    a: Array
    # Second array leaf.
    b: Array


def test__tree_concatenate__correctly_concatenates_leaves():
    """``tree_concatenate`` joins corresponding leaves of several pytrees.

    Each input leaf has shape ``(1,)``; the expected concatenated leaves have
    shape ``(2,)`` and hold the inputs' values in list order.
    """
    one = ExamplePytree(a=jnp.array([1]), b=jnp.array([2]))
    two = ExamplePytree(a=jnp.array([10]), b=jnp.array([20]))

    cat = tree_concatenate([one, two])

    # array_equal (unlike allclose) does not broadcast, so a mis-shaped result
    # cannot pass by accident; the leaves are integers, so compare exactly.
    assert jnp.array_equal(cat.a, jnp.array([1, 10]))
    assert jnp.array_equal(cat.b, jnp.array([2, 20]))


def test__batch_vmap__result_same_as_single_vmap():
    """``batch_vmap`` must match a single ``vmap`` over the whole input.

    Checked both with and without ``progress=True``, for a function that
    returns two outputs.
    """
    rng = PRNGKey(seed=1231245)
    # 100 rows is not a multiple of batch_size=3, so the last batch is
    # presumably smaller than the others; the comparison covers it too.
    xs = jax.random.normal(rng, shape=(100, 20))

    def f(xs: Array) -> tuple[Array, Array]:
        return xs + 3.0, xs * 2.0

    expected = vmap(f)(xs)
    batched = batch_vmap(f, xs, batch_size=3)
    batched_with_progress = batch_vmap(f, xs, batch_size=3, progress=True)

    for result in (batched, batched_with_progress):
        assert len(result) == len(expected)
        for got, want in zip(result, expected):
            # allclose broadcasts its operands, so check shapes explicitly
            # to catch a mis-assembled batched result.
            assert got.shape == want.shape
            assert jnp.allclose(got, want)