import jax.numpy as jnp
from jax.typing import ArrayLike

from damp.jax_utils import Array, jit


@jit
def rmse(mean: ArrayLike, ground_truth: ArrayLike) -> Array:
    """Return the root-mean-square error of ``mean`` against the interior of ``ground_truth``.

    ``ground_truth`` is cropped by one element at each end of its first two
    axes (``ground_truth[1:-1, 1:-1]``) before comparison, so ``mean`` must
    already have that cropped shape.

    Args:
        mean: Predicted values; must match the shape of the cropped
            ``ground_truth``.
        ground_truth: Reference values with at least two dimensions.

    Returns:
        A scalar array: ``sqrt(mean((mean - interior) ** 2))``.

    Raises:
        ValueError: If ``mean`` does not have the same shape as the cropped
            ``ground_truth``. NOTE(review): assuming ``damp.jax_utils.jit``
            wraps ``jax.jit``, shapes are static and this check fires during
            tracing — confirm.
    """
    # Normalise ArrayLike inputs (lists, scalars, NumPy arrays) so `.shape`
    # and multi-axis slicing below are always available; no-op for arrays.
    mean = jnp.asarray(mean)
    ground_truth = jnp.asarray(ground_truth)

    gt_interior = ground_truth[1:-1, 1:-1]
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # and a mismatch would otherwise silently broadcast into a wrong result.
    if mean.shape != gt_interior.shape:
        raise ValueError(
            f"rmse: mean has shape {mean.shape}, but the interior of "
            f"ground_truth has shape {gt_interior.shape}"
        )
    return jnp.sqrt(jnp.mean((mean - gt_interior) ** 2))