"""Search for message-passing stencil coefficients that match INLA marginals.

Fits the stencil coefficients ``c`` of damped Gaussian message passing so
that the predicted marginals match reference marginals computed via INLA.
Grid search is used by default; L-BFGS and Bayesian-optimisation variants
are kept below for comparison.
"""

from typing import Callable, Literal, ParamSpec

import chex
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import skopt
from jax import Array
from jaxopt import LBFGS
from numpy.random import default_rng
from skopt.plots import plot_convergence

from damp import gp, ground_truth_cache, inla_bridge, message_passing
from damp.gp import Shape
from damp.graph import ConstantStencilGraph, graph_and_diagonal_from_precision_matrix
from damp.jax_utils import batch_vmap
from damp.message_passing import Config, Edges, Marginals


def main() -> None:
    # TODO: Support having a c for each edge.
    # TODO: See if we get any gradient signal for the cs -- or maybe see if
    # we get any gradient signal for the single c first.
    numpy_rng = default_rng(seed=1124)

    prior = gp.get_prior(Shape(10, 10))
    ground_truth = next(ground_truth_cache.load_or_gen(prior))
    obs_noise = 1e-3
    obs = gp.choose_observations(
        numpy_rng, n_obs=150, ground_truth=ground_truth, obs_noise=obs_noise
    )

    # The INLA marginals are the targets that the c search tries to match.
    mean_targets_numpy, std_targets_numpy = inla_bridge.run(prior, obs)
    mean_targets = jnp.array(mean_targets_numpy)
    std_targets = jnp.array(std_targets_numpy)

    posterior = gp.get_posterior(prior, obs, obs_noise)
    graph, Gamma_diagonal = graph_and_diagonal_from_precision_matrix(
        posterior.precision, prior.interior_shape
    )
    initial_edges = message_passing.get_initial_edges(graph, prior.interior_shape)
    Gamma_diagonal = jnp.array(Gamma_diagonal).reshape(prior.interior_shape)
    h = jnp.array(posterior.shift).reshape(prior.interior_shape)

    optimal_cs = _grid(
        mean_targets, std_targets, graph, Gamma_diagonal, h, initial_edges
    )
    # Alternative optimisers, kept for comparison:
    # optimal_cs = _lbfgs(
    #     mean_targets, std_targets, graph, Gamma_diagonal, h, initial_edges
    # )
    # optimal_cs = _bayesopt(
    #     mean_targets, std_targets, graph, Gamma_diagonal, h, initial_edges
    # )
    print("Optimal:", ", ".join(f"{x:.10f}" for x in optimal_cs))

    pred_marginals = _run_mp(optimal_cs, graph, Gamma_diagonal, h, initial_edges)

    # Targets on the top row, predictions (on the targets' colour scale)
    # on the bottom row.
    _, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))
    mean_min, mean_max = jnp.min(mean_targets), jnp.max(mean_targets)
    std_min, std_max = jnp.min(std_targets), jnp.max(std_targets)
    axes[0, 0].imshow(mean_targets)
    axes[1, 0].imshow(pred_marginals.mean, vmin=mean_min, vmax=mean_max)
    axes[0, 1].imshow(std_targets)
    axes[1, 1].imshow(pred_marginals.std, vmin=std_min, vmax=std_max)
    # Observation coordinates are shifted by one onto the interior grid.
    obs_xs, obs_ys = zip(*[(x - 1, y - 1) for (x, y), _ in obs])
    axes[0, 1].scatter(obs_ys, obs_xs, color="red")
    axes[1, 1].scatter(obs_ys, obs_xs, color="red")
    plt.savefig("plots/c_search.png")
    plt.close()

    chex.block_until_chexify_assertions_complete()


def _bayesopt(
    mean_targets: Array,
    std_targets: Array,
    graph: ConstantStencilGraph,
    Gamma_diagonal: Array,
    h: Array,
    initial_edges: Edges,
) -> Array:
    def _obj(c: list[float]) -> float:
        loss = _objective(
            jnp.array(c),
            mean_targets,
            std_targets,
            graph,
            Gamma_diagonal,
            h,
            initial_edges,
        )
        # Replace diverged runs with a large finite penalty so that
        # gp_minimize can continue.
        if jnp.isnan(loss) or jnp.isinf(loss):
            return 100_000.0
        return loss.item()

    result = skopt.gp_minimize(
        _obj,
        dimensions=[(9.0, 10.0), (1.5, 2.2), (4.0, 5.0)],
        n_jobs=6,
        verbose=True,
        n_initial_points=100,
        n_calls=200,
    )
    assert result is not None
    plot_convergence(result)
    plt.savefig("plots/c_bo_convergence.png")
    plt.close()
    return jnp.array(result.x)


def _lbfgs(
    mean_targets: Array,
    std_targets: Array,
    graph: ConstantStencilGraph,
    Gamma_diagonal: Array,
    h: Array,
    initial_edges: Edges,
) -> Array:
    obj = _value_and_jacfwd(_objective)
    opt = LBFGS(obj, value_and_grad=True, verbose=False)
    # Sanity-check evaluation of the objective and its gradient at a known
    # point before running the optimiser.
    v, g = obj(
        jnp.array([9.2653055191, 1.9183673859, 4.4897956848]),
        mean_targets,
        std_targets,
        graph,
        Gamma_diagonal,
        h,
        initial_edges,
    )
    params, _ = opt.run(
        jnp.array([9.265, 1.918, 4.490]),
        mean_targets,
        std_targets,
        graph,
        Gamma_diagonal,
        h,
        initial_edges,
    )
    return params


def _grid(
    mean_targets: Array,
    std_targets: Array,
    graph: ConstantStencilGraph,
    Gamma_diagonal: Array,
    h: Array,
    initial_edges: Edges,
) -> Array:
    # With num=1 this sweeps a single shared c; raise num to search the
    # full Cartesian product of coefficients.
    cs = _cartesian(jnp.linspace(1.0, 40.0, 1000), num=1)
    obj = lambda c: _objective(
        c, mean_targets, std_targets, graph, Gamma_diagonal, h, initial_edges
    )
    objs = batch_vmap(obj, cs, batch_size=5_000, progress=True)
    # Drop candidates for which message passing diverged.
    cs = cs[~jnp.isnan(objs)]
    objs = objs[~jnp.isnan(objs)]
    return cs[jnp.argmin(objs)]


P = ParamSpec("P")


def _value_and_jacfwd(f: Callable[P, Array]) -> Callable[P, tuple[Array, Array]]:
    """Wrap `f` to return both its value and its forward-mode Jacobian."""
    jac = jax.jacfwd(f, argnums=0)

    def func(*args: P.args, **kwargs: P.kwargs) -> tuple[Array, Array]:
        return f(*args, **kwargs), jac(*args, **kwargs)

    return func


def _objective(
    c: Array,
    mean_targets: Array,
    std_targets: Array,
    graph: ConstantStencilGraph,
    Gamma_diagonal: Array,
    h: Array,
    initial_edges: Edges,
    loss: Literal["kl", "std_only", "mean_only"] = "std_only",
) -> Array:
    marginals = _run_mp(c, graph, Gamma_diagonal, h, initial_edges)
    if loss == "kl":
        # Per-cell KL(predicted || target) between Gaussian marginals,
        # summed over the grid.
        mean_diff = (marginals.mean - mean_targets) ** 2
        var_targets = std_targets**2
        var_ratio = marginals.std**2 / var_targets
        kl = mean_diff / (2 * var_targets) + 0.5 * (var_ratio - 1 - jnp.log(var_ratio))
        return kl.sum()
    elif loss == "std_only":
        return ((marginals.std - std_targets) ** 2).sum()
    elif loss == "mean_only":
        return ((marginals.mean - mean_targets) ** 2).sum()
    else:
        raise ValueError(f"Unknown loss: {loss}")


def _run_mp(
    c: Array,
    graph: ConstantStencilGraph,
    Gamma_diagonal: Array,
    h: Array,
    initial_edges: Edges,
) -> Marginals:
    config = Config(
        graph=graph,
        c=_create_c_stencil(graph, c),
        Gamma_diagonal=Gamma_diagonal,
        h=h,
        lr=0.6,
    )
    _, marginals = message_passing.iterate(
        config, initial_edges, n_iterations=200, parallel=True, progress_bar=False
    )
    return marginals


def _create_c_stencil(
    graph: ConstantStencilGraph, values: Array
) -> ConstantStencilGraph:
    assert values.shape in ((3,), (1,)), f"Shape was {values.shape}"
    if values.shape == (1,):
        # A single shared c is broadcast to all three coefficient groups.
        values = jnp.array([values[0], values[0], values[0]])
    weights = graph.stencil_mask.astype(jnp.float32)
    # Distance-2 axial neighbours of the stencil centre (2, 2).
    weights = weights.at[2, 0].set(values[0])
    weights = weights.at[0, 2].set(values[0])
    weights = weights.at[4, 2].set(values[0])
    weights = weights.at[2, 4].set(values[0])
    # Distance-1 axial neighbours.
    weights = weights.at[2, 1].set(values[1])
    weights = weights.at[1, 2].set(values[1])
    weights = weights.at[3, 2].set(values[1])
    weights = weights.at[2, 3].set(values[1])
    # Diagonal neighbours.
    weights = weights.at[1, 1].set(values[2])
    weights = weights.at[3, 3].set(values[2])
    weights = weights.at[1, 3].set(values[2])
    weights = weights.at[3, 1].set(values[2])
    return ConstantStencilGraph(weights=weights)


def _cartesian(xs: Array, num: int) -> Array:
    """Cartesian product of `xs` with itself `num` times, shape (-1, num)."""
    zs = [xs] * num
    return jnp.stack(jnp.meshgrid(*zs), -1).reshape(-1, num)


if __name__ == "__main__":
    main()