da-message-passing / src / damp / ground_truth_cache.py
ground_truth_cache.py
Raw
import itertools
from pathlib import Path
from time import time
from typing import Iterator

import numpy as np
from numpy import ndarray
from numpy.random import default_rng

from damp import gp

# Directory (relative to the working directory) where ground truths are
# saved to and loaded from as ``.npy`` files.
DIRECTORY = Path("outputs/ground_truths")
# Seed used for the first generated ground truth of a prior when none are
# cached yet; later ones use the largest cached seed plus one.
BASE_SEED = 901823


def load_or_gen(prior: gp.Prior, start_at: int = 0) -> Iterator[ndarray]:
    """Return an endless iterator of ground truths for ``prior``.

    Cached ground truths are yielded first, in increasing seed order. Once
    the cache is exhausted, new ground truths are generated, saved to the
    cache, and yielded, forever.

    Args:
        prior: Prior whose ground truths are loaded or generated.
        start_at: Number of leading ground truths to skip. Skipped items are
            still loaded or generated (and generated ones are still saved),
            because the skipping is done by consuming the underlying
            generator.

    Returns:
        An iterator over ground-truth arrays, starting at index ``start_at``.
    """

    def gen() -> Iterator[ndarray]:
        # Directory listing order is arbitrary; sorting makes `start_at`
        # refer to the same ground truth on every run and every machine.
        for seed in sorted(_get_available_seeds(prior)):
            print(f"Loading ground truth with seed {seed}")
            yield _load_one(prior, seed)

        while True:
            print("Run out of cached ground truths, generating one...")
            start_time = time()
            gt, seed = _gen_one(prior)
            duration = time() - start_time
            print(f"Generated ground truth with seed {seed} in {duration:.2f} seconds")
            # Persist before yielding so the next _gen_one call sees this seed.
            _save(prior, gt, seed)
            yield gt

    return itertools.islice(gen(), start_at, None)


def _get_available_seeds(prior: gp.Prior) -> Iterator[int]:
    """Yield the seeds of the cached ground truths that belong to ``prior``.

    A cache file belongs to ``prior`` when it is named exactly
    ``<prior.name>_seed_<digits>.npy``. Any other file in the cache directory
    is ignored. Yields nothing if the cache directory does not exist.

    Seeds are yielded in directory-listing order, which is not sorted.
    """
    if not DIRECTORY.exists():
        return

    # Anchor on the full "<name>_seed_" prefix. A plain substring test would
    # also match other priors whose names merely contain this prior's name,
    # and would mis-parse names that themselves contain a "seed" token.
    prefix = f"{prior.name}_seed_"
    for path in DIRECTORY.iterdir():
        if path.suffix != ".npy" or not path.stem.startswith(prefix):
            continue
        seed_text = path.stem[len(prefix):]
        # Only a purely numeric remainder is a seed written by this module.
        if seed_text.isdecimal():
            yield int(seed_text)

def _load_one(prior: gp.Prior, seed: int) -> ndarray:
    """Load the cached ground truth stored for ``prior`` under ``seed``."""
    file_name = f"{_get_name(prior, seed)}.npy"
    cached_path = DIRECTORY / file_name
    return np.load(cached_path)


def _gen_one(prior: gp.Prior) -> tuple[ndarray, int]:
    """Sample a new ground truth for ``prior`` using the next unused seed.

    The seed is ``BASE_SEED`` when no ground truth is cached for ``prior``,
    otherwise one more than the largest cached seed.

    Returns:
        The sampled ground truth and the seed it was generated with.
    """
    cached_seeds = set(_get_available_seeds(prior))
    next_seed = max(cached_seeds) + 1 if cached_seeds else BASE_SEED

    rng = default_rng(next_seed)
    sample = gp.sample_prior(rng, prior)
    return sample, next_seed


def _save(prior: gp.Prior, gt: ndarray, seed: int) -> None:
    """Write ``gt`` to the cache under the name for ``prior`` and ``seed``.

    The cache directory (and any missing parents) is created if needed.
    """
    file_name = f"{_get_name(prior, seed)}.npy"
    DIRECTORY.mkdir(exist_ok=True, parents=True)
    np.save(DIRECTORY / file_name, gt)


def _get_name(prior: gp.Prior, seed: int) -> str:
    """Build the cache file stem (no extension) for ``prior`` and ``seed``."""
    stem = f"{prior.name}_seed_{seed}"
    return stem