import functools
import itertools
import math
from collections.abc import Callable
from pathlib import Path
from typing import Any

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from numpy import ndarray
from numpy.random import default_rng

from damp import gp, ground_truth_cache, inla_bridge, message_passing
from damp.gp import Obs, Shape
from damp.graph import graph_and_diagonal_from_precision_matrix
from damp.message_passing import Config
from damp.metrics import rmse
from experiments import plotting


def main() -> None:
    """Sweep message-passing hyperparameters and report RMSE against INLA.

    For every combination of ``lr``, ``c`` and ``grid_size`` the
    message-passing run is executed (or loaded from the on-disk cache under
    ``outputs/lr_c_convergence``). The RMSE at iteration 4000, divided by the
    INLA RMSE for the same grid size, is printed as a table (plain and
    LaTeX). The RMSE-versus-iteration curves for the 256 x 256 grid are then
    plotted and saved via ``plotting.save_fig``.
    """
    plotting.configure_matplotlib()
    plt.figure(figsize=(plotting.HALF_WIDTH, 1.3))

    save_dir = Path("outputs/lr_c_convergence")
    save_dir.mkdir(parents=True, exist_ok=True)

    lrs = [0.6, 0.7, 0.8]
    cs = [-10, -2, -1, 1, 5, 10, 20]
    grid_sizes = [128, 256, 512]
    hyperparameters = [
        {"lr": lr, "c": c, "grid_size": s}
        for c, lr, s in itertools.product(cs, lrs, grid_sizes)
    ]
    y_top = 0.5

    # Baseline RMSE per grid size, used to normalise the table entries and
    # drawn as a horizontal line in the plot.
    inla_rmses = {}
    for grid_size in grid_sizes:
        inla_result = _load_or_compute(
            save_dir / f"inla_{grid_size}.npy",
            functools.partial(_run_inla, grid_size),
        )
        inla_rmses[grid_size] = inla_result["rmse"].item()

    results = []
    for hyps in hyperparameters:
        result = _load_or_compute(
            save_dir / f"{_dict_to_str(hyps)}.npy",
            functools.partial(_run_mp, hyps),
        )
        # Normalise by the INLA RMSE of *this* configuration's grid size.
        # (Indexing with the leftover loop variable of the INLA loop above
        # would always pick the last grid size.)
        mid_rmse = (
            result["rmses"][result["steps"] == 4000].item()
            / inla_rmses[hyps["grid_size"]]
        )
        results.append({"mid rmse": mid_rmse} | hyps | result)

    df = pd.DataFrame(results)

    # Table: rows are (grid_size, lr), columns are c, cells are "mid rmse".
    to_print = df[["mid rmse", "lr", "c", "grid_size"]]
    to_print = to_print.pivot(
        index=["grid_size", "lr"], columns="c", values="mid rmse"
    )
    to_print.columns.name = None
    to_print.index = to_print.index.map(
        lambda x: (f"${x[0]} \\times {x[0]}$", f"{x[1]:.1f}")
    )
    # NaN entries are shown as "-".
    to_print = to_print.map(lambda x: f"{x:.2f}" if not np.isnan(x) else "-")
    with pd.option_context("display.max_rows", None):
        print(to_print)
    print(to_print.to_latex(float_format="%.2f"))

    plot_grid_size = 256
    to_plot = df[
        (df["grid_size"] == plot_grid_size)
        & (df["c"].isin([-10, -1, 1, 5, 10, 20]))
    ]
    # NOTE(review): "c" is constant within each group, so idxmax returns the
    # first row of every group (the first lr listed for that c). Confirm that
    # this is the intended selection rather than e.g. the largest lr.
    to_plot = to_plot.loc[to_plot.groupby("c")["c"].idxmax()]

    for i, (_, row) in enumerate(to_plot.iterrows()):
        steps, rmses = row["steps"], row["rmses"]
        bad = np.isnan(rmses)
        c = row["c"]
        c_str = c if c < 100.0 else f"10^{round(math.log(c, 10))}"
        label = f"$c={c_str}$"
        color = f"C{i}"
        if np.any(bad):
            # Runs containing NaN RMSEs get a marker placed left of the
            # x-axis limit, so that only their legend entry is visible.
            plt.scatter([-10], [-10], label=label, marker="x", color=color)
        else:
            plt.plot(steps[~bad], rmses[~bad], label=label, color=color)

    plt.axhline(inla_rmses[plot_grid_size], color="black")
    plt.legend(**plotting.squashed_legend_params, ncols=2)
    plt.xlim(left=0, right=4000)
    plt.ylim(bottom=inla_rmses[plot_grid_size] - 0.05, top=y_top)
    plt.xlabel("iterations", **plotting.squashed_label_params)
    plt.ylabel("RMSE", **plotting.squashed_label_params)
    plt.tight_layout(pad=0.2)
    plotting.save_fig("lr_c_convergence")
    plt.close()


def _load_or_compute(
    path: Path, compute: Callable[[], dict[str, ndarray]]
) -> dict[str, ndarray]:
    """Return the dict cached at *path*, computing and saving it if absent.

    Args:
        path: Location of the ``.npy`` cache file.
        compute: Zero-argument callable producing the dict on a cache miss.

    Returns:
        The cached dict if *path* exists, otherwise the freshly computed one
        (which is also written to *path*).
    """
    if path.exists():
        # The dict is stored as a pickled 0-d object array, hence
        # allow_pickle=True and .item(). Only load cache files written by
        # this script: unpickling untrusted files is unsafe.
        return np.load(path, allow_pickle=True).item()
    result = compute()
    np.save(path, result)
    return result


def _dict_to_str(d: dict[str, Any]) -> str:
    """Serialise *d* as ``key_value`` pairs joined by ``_``, sorted by key."""
    return "_".join(f"{k}_{v}" for k, v in sorted(d.items()))


def _run_mp(hyps: dict[str, Any]) -> dict[str, ndarray]:
    """Run message passing for one hyperparameter setting.

    Args:
        hyps: Must contain ``"grid_size"``; every other entry is forwarded
            to ``Config`` as a keyword argument.

    Returns:
        A dict with ``"steps"`` (the saved iteration numbers) and ``"rmses"``
        (the RMSE of the marginal means against the ground truth at each
        saved iteration).
    """
    prior, gt, obs_noise, obs = _get_prior(hyps["grid_size"])
    posterior = gp.get_posterior(prior, obs, obs_noise)
    graph, Gamma_diagonal = graph_and_diagonal_from_precision_matrix(
        posterior.precision
    )
    config = Config(
        graph=graph,
        Gamma_diagonal=jnp.array(Gamma_diagonal),
        h=jnp.array(posterior.shift),
        # "grid_size" selects the problem; it is not a Config field.
        **{k: v for k, v in hyps.items() if k != "grid_size"},
    )
    initial_edges = message_passing.get_initial_edges(graph)
    history = message_passing.iterate_with_history(
        config, initial_edges, n_iterations=5000, save_every=50
    )
    steps = np.array([step for step, _, _ in history])
    rmses = np.stack(
        [
            rmse(marginals.mean.reshape(prior.interior_shape), gt)
            for _, _, marginals in history
        ]
    )
    return {"steps": steps, "rmses": rmses}


def _run_inla(grid_size: int) -> dict[str, ndarray]:
    """Run ``inla_bridge`` on the problem for *grid_size*.

    Returns:
        A dict whose ``"rmse"`` entry is the RMSE of the returned mean
        against the ground truth, wrapped in an array.
    """
    prior, gt, _, obs = _get_prior(grid_size)
    # The second return value is not needed here.
    mean, _ = inla_bridge.run(prior, obs)
    return {"rmse": np.array(rmse(mean, gt))}


@functools.cache
def _get_prior(grid_size: int) -> tuple[gp.Prior, ndarray, float, Obs]:
    """Build the test problem for a ``grid_size`` x ``grid_size`` grid.

    The result is memoised per *grid_size* and the random generator is
    seeded with a fixed value, so repeated calls return the same problem.

    Returns:
        ``(prior, ground_truth, obs_noise, obs)``, where ``obs`` is produced
        by ``gp.choose_observations`` with ``n_obs`` set to 5% of
        ``prior.shape.flatten()`` (rounded).
    """
    numpy_rng = default_rng(seed=1293123)
    prior = gp.get_prior(Shape(grid_size, grid_size))
    ground_truth = next(ground_truth_cache.load_or_gen(prior))
    obs_noise = 1e-3
    obs = gp.choose_observations(
        numpy_rng,
        n_obs=round(prior.shape.flatten() * 0.05),
        ground_truth=ground_truth,
        obs_noise=obs_noise,
    )
    return prior, ground_truth, obs_noise, obs


if __name__ == "__main__":
    main()