# da-message-passing/src/experiments_archive/c_search.py
from typing import Callable, Literal, ParamSpec

import chex
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import skopt
from jax import Array
from jaxopt import LBFGS
from numpy import ndarray
from numpy.random import default_rng
from skopt.plots import plot_convergence

from damp import gp, ground_truth_cache, inla_bridge, message_passing
from damp.gp import Shape
from damp.graph import ConstantStencilGraph, graph_and_diagonal_from_precision_matrix
from damp.jax_utils import batch_vmap
from damp.message_passing import Config, Edges, Marginals


def main() -> None:
    """Search for message-passing stencil coefficients that reproduce INLA.

    Builds a small GP posterior on a 10x10 grid, computes reference mean/std
    fields via the INLA bridge, searches for stencil coefficients `c` whose
    message-passing marginals match those targets, then plots predictions
    against the targets.
    """
    # Support having a c for each edge
    # See if we get any gradient signal for the cs
    # Or maybe: see if we get any gradient signal for the single c first...
    numpy_rng = default_rng(seed=1124)

    prior = gp.get_prior(Shape(10, 10))
    ground_truth = next(ground_truth_cache.load_or_gen(prior))
    obs_noise = 1e-3
    obs = gp.choose_observations(
        numpy_rng, n_obs=150, ground_truth=ground_truth, obs_noise=obs_noise
    )

    # Reference marginals from INLA — these are the targets the c-search
    # tries to match.
    mean_targets_numpy, std_targets_numpy = inla_bridge.run(prior, obs)
    mean_targets = jnp.array(mean_targets_numpy)
    std_targets = jnp.array(std_targets_numpy)

    posterior = gp.get_posterior(prior, obs, obs_noise)
    graph, Gamma_diagonal = graph_and_diagonal_from_precision_matrix(
        posterior.precision, prior.interior_shape
    )
    initial_edges = message_passing.get_initial_edges(graph, prior.interior_shape)
    Gamma_diagonal = jnp.array(Gamma_diagonal).reshape(prior.interior_shape)
    h = jnp.array(posterior.shift).reshape(prior.interior_shape)

    # Grid search is the active strategy; the L-BFGS and Bayesian-optimisation
    # alternatives are kept below for experimentation.
    optimal_cs = _grid(
        mean_targets, std_targets, graph, Gamma_diagonal, h, initial_edges
    )
    # optimal_cs = _lbfgs(
    #     mean_targets, std_targets, graph, Gamma_diagonal, h, initial_edges
    # )
    # optimal_cs = _bayesopt(
    #     mean_targets, std_targets, graph, Gamma_diagonal, h, initial_edges
    # )
    print("Optimal:", ", ".join(f"{x:.10f}" for x in optimal_cs))

    pred_marginals = _run_mp(optimal_cs, graph, Gamma_diagonal, h, initial_edges)

    # Top row: INLA targets; bottom row: message-passing predictions with the
    # colour scale pinned to the target range so the rows are comparable.
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))

    mean_min, mean_max = jnp.min(mean_targets), jnp.max(mean_targets)
    std_min, std_max = jnp.min(std_targets), jnp.max(std_targets)
    axes[0, 0].imshow(mean_targets)
    axes[1, 0].imshow(pred_marginals.mean, vmin=mean_min, vmax=mean_max)
    axes[0, 1].imshow(std_targets)
    axes[1, 1].imshow(pred_marginals.std, vmin=std_min, vmax=std_max)
    # NOTE(review): the -1 offsets suggest obs coordinates are on the full
    # grid while the images show the interior — confirm against gp.choose_observations.
    obs_xs, obs_ys = zip(*[(x - 1, y - 1) for (x, y), val in obs])
    axes[0, 1].scatter(obs_ys, obs_xs, color="red")
    axes[1, 1].scatter(obs_ys, obs_xs, color="red")
    plt.savefig("plots/c_search.png")
    plt.close()

    chex.block_until_chexify_assertions_complete()


def _bayesopt(
    mean_targets: Array,
    std_targets: Array,
    graph: ConstantStencilGraph,
    Gamma_diagonal: Array,
    h: Array,
    initial_edges: Edges,
) -> Array:
    """Search the three stencil coefficients with scikit-optimize's GP minimiser.

    Saves a convergence plot to plots/c_bo_convergence.png and returns the
    best coefficient vector found.
    """

    def skopt_objective(c: list[float]) -> float:
        value = _objective(
            jnp.array(c),
            mean_targets,
            std_targets,
            graph,
            Gamma_diagonal,
            h,
            initial_edges,
        )
        # skopt cannot cope with NaN/inf losses (diverged message passing),
        # so clamp those to a large finite penalty.
        if jnp.isnan(value) or jnp.isinf(value):
            return 100_000.0
        return value.item()

    result = skopt.gp_minimize(
        skopt_objective,
        dimensions=[(9.0, 10.0), (1.5, 2.2), (4.0, 5.0)],
        n_jobs=6,
        verbose=True,
        n_initial_points=100,
        n_calls=200,
    )
    assert result is not None

    plot_convergence(result)
    plt.savefig("plots/c_bo_convergence.png")
    plt.close()

    return jnp.array(result.x)


def _lbfgs(
    mean_targets: Array,
    std_targets: Array,
    graph: ConstantStencilGraph,
    Gamma_diagonal: Array,
    h: Array,
    initial_edges: Edges,
) -> Array:
    """Locally optimise the stencil coefficients with L-BFGS.

    The objective is differentiated in forward mode with respect to `c`; the
    start point is a previously-found good coefficient vector.  Returns the
    optimised coefficients.
    """
    # L-BFGS is fed (value, gradient) pairs directly via value_and_grad=True.
    # (A leftover debugging evaluation of the objective before opt.run was
    # removed — its result was never used and it cost a full MP run + jacobian.)
    obj = _value_and_jacfwd(_objective)
    opt = LBFGS(obj, value_and_grad=True, verbose=False)
    params, _ = opt.run(
        jnp.array([9.265, 1.918, 4.490]),
        mean_targets,
        std_targets,
        graph,
        Gamma_diagonal,
        h,
        initial_edges,
    )
    return params


def _grid(
    mean_targets: Array,
    std_targets: Array,
    graph: ConstantStencilGraph,
    Gamma_diagonal: Array,
    h: Array,
    initial_edges: Edges,
) -> Array:
    """Exhaustive grid search over a single shared stencil coefficient.

    Evaluates the objective on 1000 candidate values in [1, 40] and returns
    the candidate (shape (1,)) with the smallest finite objective.
    """
    candidates = _cartesian(jnp.linspace(1.0, 40.0, 1000), num=1)

    def obj(c: Array) -> Array:
        return _objective(
            c, mean_targets, std_targets, graph, Gamma_diagonal, h, initial_edges
        )

    objs = batch_vmap(obj, candidates, batch_size=5_000, progress=True)
    # Drop candidates where message passing diverged (NaN objective).  The
    # mask is computed once so candidates and objectives stay aligned.
    finite = ~jnp.isnan(objs)
    candidates = candidates[finite]
    objs = objs[finite]
    return candidates[jnp.argmin(objs)]


P = ParamSpec("P")


def _value_and_jacfwd(f: Callable[P, Array]):
    jac = jax.jacfwd(f, argnums=0)

    def func(*args: P.args, **kwargs: P.kwargs) -> tuple[Array, Array]:
        return f(*args, **kwargs), jac(*args, **kwargs)

    return func


def _objective(
    c: Array,
    mean_targets: Array,
    std_targets: Array,
    graph: ConstantStencilGraph,
    Gamma_diagonal: Array,
    h: Array,
    initial_edges: Edges,
    loss: Literal["kl", "std_only", "mean_only"] = "std_only",
) -> Array:
    """Scalar discrepancy between MP marginals under `c` and the targets.

    Runs message passing with stencil coefficients `c` and reduces the
    predicted marginals against the target mean/std fields with the chosen
    `loss`.  Raises ValueError for an unrecognised `loss` (previously the
    function silently returned None in that case).
    """
    marginals = _run_mp(c, graph, Gamma_diagonal, h, initial_edges)

    if loss == "kl":
        # Per-pixel KL(pred || target) for univariate Gaussians, summed.
        mean_diff = (marginals.mean - mean_targets) ** 2
        var_targets = std_targets**2
        var_ratio = marginals.std**2 / var_targets
        kl = mean_diff / (2 * var_targets) + 0.5 * (var_ratio - 1 - jnp.log(var_ratio))
        return kl.sum()
    if loss == "std_only":
        return ((marginals.std - std_targets) ** 2).sum()
    if loss == "mean_only":
        return ((marginals.mean - mean_targets) ** 2).sum()
    raise ValueError(f"Unknown loss: {loss!r}")


def _run_mp(
    c: Array,
    graph: ConstantStencilGraph,
    Gamma_diagonal: Array,
    h: Array,
    initial_edges: Edges,
) -> Marginals:
    """Run 200 parallel message-passing iterations with stencil coefficients `c`."""
    stencil = _create_c_stencil(graph, c)
    config = Config(
        graph=graph,
        c=stencil,
        Gamma_diagonal=Gamma_diagonal,
        h=h,
        lr=0.6,
    )
    _, marginals = message_passing.iterate(
        config, initial_edges, n_iterations=200, parallel=True, progress_bar=False
    )
    return marginals


def _create_c_stencil(
    graph: ConstantStencilGraph, values: Array
) -> ConstantStencilGraph:
    """Build a stencil-weight graph from up to three coefficients.

    `values` has shape (3,) — one coefficient per neighbour class — or shape
    (1,), in which case the single coefficient is shared by all classes.  The
    indices below require a 5x5 stencil centred at (2, 2):
      values[0]: the four axial neighbours at distance 2,
      values[1]: the four axial neighbours at distance 1,
      values[2]: the four diagonal neighbours.
    """
    assert values.shape == (3,) or values.shape == (1,), f"Shape was {values.shape}"
    if values.shape == (1,):
        # Broadcast the shared coefficient explicitly.  (The old code tested
        # for shape (), which the assert above makes unreachable; the (1,)
        # case — used by the grid search — only worked because JAX clamps
        # out-of-range indices like values[1] back to values[0].)
        values = jnp.repeat(values, 3)

    weights = graph.stencil_mask.astype(jnp.float32)

    # Axial neighbours at distance 2.
    weights = weights.at[2, 0].set(values[0])
    weights = weights.at[0, 2].set(values[0])
    weights = weights.at[4, 2].set(values[0])
    weights = weights.at[2, 4].set(values[0])

    # Axial neighbours at distance 1.
    weights = weights.at[2, 1].set(values[1])
    weights = weights.at[1, 2].set(values[1])
    weights = weights.at[3, 2].set(values[1])
    weights = weights.at[2, 3].set(values[1])

    # Diagonal neighbours.
    weights = weights.at[1, 1].set(values[2])
    weights = weights.at[3, 3].set(values[2])
    weights = weights.at[1, 3].set(values[2])
    weights = weights.at[3, 1].set(values[2])

    return ConstantStencilGraph(weights=weights)


def _cartesian(xs: Array, num: int) -> Array:
    zs = [xs] * num
    return jnp.stack(jnp.meshgrid(*zs), -1).reshape(-1, num)


# Script entry point.
if __name__ == "__main__":
    main()