da-message-passing / src / damp / gp.py
gp.py
Raw
from dataclasses import dataclass
from typing import NamedTuple

import numpy as np
from findiff import PDE, BoundaryConditions, FinDiff, Identity, Coefficient
from numpy import ndarray
from numpy.random import Generator
from scipy.sparse import diags, spmatrix
from scipy.special import gamma

# One grid location as an (axis-0 index, axis-1 index) pair.
_GridIndex = tuple[int, int]
# Point observations: a list of (grid location, observed value) pairs.
Obs = list[tuple[_GridIndex, ndarray]]


class Shape(NamedTuple):
    """Grid size; ``width`` counts nodes along axis 0, ``height`` along axis 1."""

    width: int
    height: int

    def flatten(self) -> int:
        """Return the total number of grid nodes, ``width * height``."""
        n_axis0, n_axis1 = self
        return n_axis0 * n_axis1


@dataclass(frozen=True)
class Prior:
    """Discretised Gaussian-process prior on a rectangular grid.

    ``precision`` and ``precision_decomposed`` act on the interior nodes only
    (the outermost node on every side is excluded); ``grid_idxs`` and
    ``interior_idxs`` relate the full grid to that interior ordering.
    """

    d: int  # spatial dimension
    ls: float  # length scale
    nu: float  # Matérn smoothness parameter
    amp: float  # amplitude; the precision is divided by amp**2
    shape: Shape  # full grid shape, boundary nodes included
    precision: spmatrix  # precision_decomposed.T @ precision_decomposed
    precision_decomposed: spmatrix  # scaled operator factor L
    grid_idxs: ndarray  # flat (C-order) index of every node, shaped like `shape`
    interior_idxs: ndarray  # flat indices of the non-boundary nodes

    @property
    def interior_shape(self) -> Shape:
        """Grid shape once the boundary node on each side is stripped."""
        trim = 2
        return Shape(width=self.shape.width - trim, height=self.shape.height - trim)

    @property
    def name(self) -> str:
        """Identifier derived from the grid size, e.g. ``"size_10_20"``."""
        dims = (self.shape.width, self.shape.height)
        return "size_{}_{}".format(*dims)


@dataclass(frozen=True)
class Posterior:
    """Gaussian posterior over the interior grid nodes, as built by ``get_posterior``.

    NOTE(review): ``(shift, precision)`` look like canonical/information-form
    parameters (mean = precision^-1 @ shift for a zero-mean prior) — confirm
    against the consumers of this class.
    """

    # Observation value / obs_noise**2 at observed interior nodes, zero elsewhere.
    shift: ndarray
    # Prior precision plus obs_noise**-2 on the diagonal at observed interior nodes.
    precision: spmatrix
    # Observation noise standard deviation used to build the fields above.
    obs_noise: float
    # Dense 0/1 indicator over interior nodes marking observed locations
    # (get_posterior passes a numpy array here, hence the ndarray annotation).
    obs_location_mask: ndarray


def get_prior(shape: Shape, *, ls: float = 0.15, amp: float = 1.1) -> Prior:
    """Build the Matérn (nu = 1) GMRF prior on the unit square.

    The square [0, 1] x [0, 1] is sampled with ``shape.width`` nodes along
    axis 0 and ``shape.height`` nodes along axis 1. The operator
    ``kappa**2 - d^2/dx^2 - d^2/dy^2`` is discretised with finite differences.
    It is restricted to the interior nodes, so boundary values are treated as
    zero. It is then scaled by ``sqrt(dx * dy / (q * amp**2))`` to give the
    factor ``L``, and the precision is ``L.T @ L``.

    Args:
        shape: Full grid size, boundary nodes included.
        ls: Length scale. Keyword-only; the default is the value that used
            to be hard-coded here.
        amp: Amplitude; the precision is divided by ``amp**2``. Keyword-only;
            the default is the value that used to be hard-coded here.

    Returns:
        A ``Prior`` whose ``precision`` / ``precision_decomposed`` act on the
        interior nodes only.
    """
    # nu is deliberately not a parameter. The precision applies the operator
    # (kappa**2 - Laplacian) twice (L.T @ L), which is the Matérn-1 case in
    # d = 2 (SPDE exponent nu + d/2 = 2). Other nu values would need a
    # different operator.
    d = 2
    nu = 1
    x = np.linspace(0, 1, shape.width)
    y = np.linspace(0, 1, shape.height)

    dx = x[1] - x[0]
    dy = y[1] - y[0]

    # SPDE left-hand side on the full grid.
    diff_op = _kappa(nu, ls) ** 2 * Identity() - FinDiff(0, dx, 2) - FinDiff(1, dy, 2)

    # Interior-only operator matrix, scaled by the cell area and the
    # normalisation q * amp**2, is the factor L of the precision L.T @ L.
    mat = _operator_to_matrix(diff_op, shape)
    precision_decomposed = np.sqrt(dx * dy / (_q(d, nu, ls) * amp**2)) * mat
    precision = precision_decomposed.T @ precision_decomposed

    return Prior(
        d,
        ls,
        nu,
        amp,
        shape,
        precision,
        precision_decomposed,
        interior_idxs=_get_interior_indices(shape),
        grid_idxs=_get_domain_indices(shape),
    )


def get_prior_sphere(
    shape: Shape,
    lon: ndarray,
    lat: ndarray,
    *,
    ls: float = 0.2,
    amp: float = 1.9,
) -> Prior:
    """Build the Matérn (nu = 1) GMRF prior on a lat/lon grid of the unit sphere.

    The flat Laplacian is replaced by the spherical Laplace–Beltrami operator.
    Each interior node's row of the operator is scaled by its area element
    ``sin(phi) * dtheta * dphi`` divided by ``q * amp**2``, and the precision
    is ``L.T @ L``. Boundary nodes on every edge are dropped, including the
    first and last longitude, so longitude is not treated as periodic.

    Args:
        shape: Full grid size. Must equal ``(len(lat), len(lon))``: the
            coefficient arrays below put latitude on axis 0 and longitude on
            axis 1.
        lon: Longitudes in degrees. Assumed evenly spaced; only the first
            spacing is used.
        lat: Latitudes in degrees. Assumed evenly spaced and presumably in
            [-90, 90] — TODO confirm.
        ls: Length scale (keyword-only; default is the previous hard-coded value).
        amp: Amplitude; the precision is divided by ``amp**2`` (keyword-only;
            default is the previous hard-coded value).

    Returns:
        A ``Prior`` whose matrices act on the interior nodes only.
    """
    d = 2
    nu = 1
    # Shift latitude by 90 degrees so it becomes a polar angle running
    # 0 -> 180 deg. The cot/csc coefficients below diverge at its ends, which
    # are boundary nodes and are dropped.
    phi = np.radians(np.asarray(lat) + 90)
    theta = np.radians(np.asarray(lon))
    dtheta, dphi = theta[1] - theta[0], phi[1] - phi[0]

    # Default "xy" indexing gives arrays of shape (len(phi), len(theta)):
    # axis 0 is latitude and axis 1 is longitude, matching the FinDiff axes.
    _, Phi = np.meshgrid(theta, phi)

    # kappa**2 minus the Laplace–Beltrami operator on the unit sphere:
    #   d^2/dphi^2 + cot(phi) d/dphi + sin(phi)**-2 d^2/dtheta^2
    diff_op = (
        _kappa(nu, ls) ** 2 * Identity()
        - Coefficient(1 / np.tan(Phi)) * FinDiff(0, dphi)
        - FinDiff(0, dphi, 2)
        - Coefficient((1 / np.sin(Phi)) ** 2) * FinDiff(1, dtheta, 2)
    )
    mat = _operator_to_matrix(diff_op, shape)

    # Per-node scaling is computed on a dense vector, so np.sin/np.sqrt never
    # see a sparse matrix. It is applied as an explicit diagonal matrix
    # product with `@`, which is a matrix product for both scipy sparse
    # matrices and sparse arrays (`*` is element-wise for the latter).
    area = np.sin(Phi[1:-1, 1:-1]).flatten() * dtheta * dphi
    weights = area / (_q(d, nu, ls) * amp**2)
    precision_decomposed = diags(np.sqrt(weights)) @ mat
    precision = precision_decomposed.T @ precision_decomposed
    return Prior(
        d,
        ls,
        nu,
        amp,
        shape,
        precision,
        precision_decomposed,
        interior_idxs=_get_interior_indices(shape),
        grid_idxs=_get_domain_indices(shape),
    )


def sample_prior(rng: Generator, prior: Prior) -> ndarray:
    """Draw one realisation of a flat-domain prior by solving its SPDE.

    Solves ``(kappa**2 - d^2/dx^2 - d^2/dy^2) u = W`` on the unit square with
    ``u = 0`` on all four edges. ``W`` is standard-normal noise from ``rng``
    scaled by ``sqrt(q) * amp / sqrt(dx * dy)``.

    NOTE(review): the operator here is always the flat-square one, so it does
    not correspond to priors built by ``get_prior_sphere`` — confirm callers
    only pass flat priors.

    Returns:
        The sample on the full grid, boundary included (presumably shaped like
        ``prior.shape``, the shape of the right-hand side).
    """
    grid_x = np.linspace(0, 1, prior.shape.width)
    grid_y = np.linspace(0, 1, prior.shape.height)
    hx = grid_x[1] - grid_x[0]
    hy = grid_y[1] - grid_y[0]

    # Left-hand side operator.
    kappa_sq = _kappa(prior.nu, prior.ls) ** 2
    lhs = kappa_sq * Identity() - FinDiff(0, hx, 2) - FinDiff(1, hy, 2)

    # Right-hand side: white noise scaled so the sample matches the prior precision.
    noise_scale = (hx * hy) ** (-0.5) * np.sqrt(_q(prior.d, prior.nu, prior.ls)) * prior.amp
    rhs = noise_scale * rng.normal(size=prior.shape)

    # Zero-Dirichlet condition on each edge: first/last row, first/last column.
    bc = BoundaryConditions(prior.shape)
    edges = (
        (0, slice(None)),
        (-1, slice(None)),
        (slice(None), 0),
        (slice(None), -1),
    )
    for edge in edges:
        bc[edge] = 0

    return PDE(lhs, rhs, bc).solve()


def get_posterior(
    prior: Prior, observations: Obs, obs_noise: float = 1e-3
) -> Posterior:
    """Condition ``prior`` on point observations with i.i.d. Gaussian noise.

    Each observation ``((i, j), y)`` marks node ``(i, j)`` as observed, which
    adds ``obs_noise**-2`` to that node's diagonal precision. It also sets the
    shift at that node to ``y / obs_noise**2``. Results are then restricted to
    the interior nodes, so observations on boundary nodes have no effect.

    NOTE(review): a node observed more than once contributes ``obs_noise**-2``
    only once, and only the last value lands in the shift.

    Args:
        prior: Supplies the grid layout and the prior precision.
        observations: ``(grid location, value)`` pairs.
        obs_noise: Observation noise standard deviation.

    Returns:
        A ``Posterior`` over the interior nodes. ``obs_location_mask`` is a
        dense 0/1 vector marking observed interior nodes.
    """
    grid = prior.grid_idxs
    interior = prior.interior_idxs
    n_nodes = np.prod(grid.shape)

    # Pass 1: indicator of observed nodes -> diagonal precision update.
    observed_nodes = [grid[location] for location, _ in observations]
    indicator = np.zeros(n_nodes)
    indicator[observed_nodes] = 1
    noise_term = obs_noise ** (-2) * diags(indicator[interior])
    posterior_precision = prior.precision + noise_term

    # Pass 2: value / obs_noise**2 at each observed node (later values at the
    # same node overwrite earlier ones).
    shift = np.zeros(np.prod(grid.shape))
    for location, value in observations:
        shift[grid[location]] = value / obs_noise**2

    return Posterior(
        shift[interior],
        posterior_precision,
        obs_noise,
        obs_location_mask=indicator[interior],
    )


def _get_domain_indices(shape):
    """Return an integer array of shape ``shape`` holding ``0 .. prod(shape) - 1``.

    Element ``[i, j]`` is the flat row-major (C-order) index of node
    ``(i, j)``, so the array maps grid coordinates to positions in flattened
    vectors.
    """
    # np.arange builds the range directly instead of via a Python list.
    # int() is needed because np.prod(()) is the float 1.0, which range() and
    # integer-valued arange would otherwise reject or mistype.
    return np.arange(int(np.prod(shape))).reshape(shape)


def _get_interior_indices(shape) -> ndarray:
    """Return, in C order, the flat indices of every node not on the grid boundary.

    The first and last node along each axis count as boundary.
    """
    all_nodes = _get_domain_indices(shape)
    strip_boundary = (slice(1, -1),) * len(shape)
    return all_nodes[strip_boundary].flatten()


def _operator_to_matrix(diff_op, shape):
    """Assemble ``diff_op`` on the full grid and keep only interior rows and columns.

    Dropping the boundary columns amounts to treating boundary values as
    zero. The result is the unscaled operator matrix on the interior nodes;
    callers scale it and form ``L.T @ L`` to get a precision, so it is not a
    precision itself.
    """
    full = diff_op.matrix(shape)
    keep = _get_interior_indices(shape)
    return full[keep][:, keep]


def _kappa(nu: float, ls: float) -> float:
    """Return the Matérn SPDE scale parameter ``kappa = sqrt(2 * nu) / ls``."""
    root_two_nu = np.sqrt(2 * nu)
    return root_two_nu / ls


def _q(d: int, nu: float, ls: float) -> float:
    """Return the normalisation constant used to scale the Matérn prior.

        q = (4*pi)**(d/2) * kappa**(2*nu) * Gamma(nu + d/2) / Gamma(nu)

    with ``kappa = sqrt(2 * nu) / ls``. The prior builders divide the
    precision by ``q * amp**2``, and the sampler multiplies its noise by
    ``sqrt(q) * amp``.
    """
    kappa = np.sqrt(2 * nu) / ls
    numerator = (4 * np.pi) ** (d / 2) * kappa ** (2 * nu) * gamma(nu + d / 2)
    return numerator / gamma(nu)


def choose_observations(
    rng: Generator, n_obs: int, ground_truth: ndarray, obs_noise: float
) -> Obs:
    """Pick ``n_obs`` distinct interior grid nodes and return noisy readings of them.

    Nodes are drawn uniformly without replacement with ``rng.choice``. The
    first and last index along each of the first two axes of
    ``ground_truth`` are never chosen. Each reading is the true value plus
    ``obs_noise`` times one ``rng.normal()`` draw, taken in selection order.
    """
    rows = np.arange(ground_truth.shape[0])[1:-1]
    cols = np.arange(ground_truth.shape[1])[1:-1]
    row_grid, col_grid = np.meshgrid(rows, cols, indexing="ij")
    candidates = np.column_stack((row_grid.ravel(), col_grid.ravel()))
    chosen = rng.choice(candidates, n_obs, replace=False)

    readings = []
    for row, col in chosen:
        noisy_value = ground_truth[(row, col)] + obs_noise * rng.normal()
        readings.append(((row, col), noisy_value))
    return readings