da-message-passing / src / damp / metrics.py
metrics.py
Raw
import jax.numpy as jnp
from jax.typing import ArrayLike

from damp.jax_utils import Array, jit


@jit
def rmse(mean: ArrayLike, ground_truth: ArrayLike) -> Array:
    """Return the root-mean-square error of ``mean`` against the interior of ``ground_truth``.

    The interior is ``ground_truth`` with its first and last row and its first
    and last column removed (``ground_truth[1:-1, 1:-1]``). ``mean`` must have
    exactly that shape.

    Args:
        mean: Estimate to score; same shape as the interior of ``ground_truth``.
        ground_truth: Reference array with at least two dimensions.

    Returns:
        Scalar array equal to ``sqrt(mean((mean - interior) ** 2))``.

    Raises:
        ValueError: If ``mean`` does not have the interior's shape.
    """
    estimate = jnp.asarray(mean)
    # NOTE(review): presumably ground_truth carries a one-cell border that the
    # estimate does not cover - confirm against callers.
    gt_interior = jnp.asarray(ground_truth)[1:-1, 1:-1]
    # An explicit raise (unlike `assert`) survives `python -O`; without this
    # check a mismatched estimate could silently broadcast and give a wrong
    # score. Assuming `jit` traces like `jax.jit`, shapes are concrete here, so
    # the check runs once at trace time.
    if estimate.shape != gt_interior.shape:
        raise ValueError(
            f"mean has shape {estimate.shape}, expected {gt_interior.shape} "
            "(ground_truth with its outer border removed)"
        )
    return jnp.sqrt(jnp.mean((estimate - gt_interior) ** 2))