"""Disk-backed cache of ground-truth samples drawn from a GP prior.

Samples are stored as ``<prior.name>_seed_<seed>.npy`` files inside
``DIRECTORY``. Cached samples are served first; once they run out, new ones
are generated, saved and served indefinitely.
"""

import itertools
from pathlib import Path
from time import time
from typing import Iterator

import numpy as np
from numpy import ndarray
from numpy.random import default_rng

from damp import gp

DIRECTORY = Path("outputs/ground_truths")
BASE_SEED = 901823


def load_or_gen(prior: gp.Prior, start_at: int = 0) -> Iterator[ndarray]:
    """Return an endless iterator of ground truths for ``prior``.

    Cached ground truths are yielded first, in ascending seed order, then new
    ones are generated (and saved to ``DIRECTORY``) on demand.

    Args:
        prior: Prior whose ``name`` identifies the cached files and which is
            passed to ``gp.sample_prior`` when generating.
        start_at: Number of leading ground truths to skip.

    Returns:
        A never-ending iterator of arrays.
    """

    def gen() -> Iterator[ndarray]:
        for seed in _get_available_seeds(prior):
            print(f"Loading ground truth with seed {seed}")
            yield _load_one(prior, seed)

        while True:
            print("Run out of cached ground truths, generating one...")
            start_time = time()
            gt, seed = _gen_one(prior)
            duration = time() - start_time
            print(f"Generated grouth truth with seed {seed} in {duration:.2f} seconds")
            _save(prior, gt, seed)
            yield gt

    # islice consumes the skipped items, so the first ``start_at`` ground
    # truths are still loaded (or generated and saved) before being dropped.
    return itertools.islice(gen(), start_at, None)


def _get_available_seeds(prior: gp.Prior) -> Iterator[int]:
    """Yield, in ascending order, the seeds cached on disk for ``prior``.

    Only files whose stem is exactly what ``_get_name`` produces for this
    prior are considered, so files belonging to a prior whose name merely
    contains this prior's name (or stray files) are ignored.
    """
    if not DIRECTORY.exists():
        return

    prefix = _get_name(prior, 0)[:-1]  # "<prior.name>_seed_"
    seeds = []
    for path in DIRECTORY.iterdir():
        if path.suffix != ".npy" or not path.stem.startswith(prefix):
            continue
        try:
            seed = int(path.stem[len(prefix):])
        except ValueError:
            continue
        # Round-trip check: rejects spellings such as "007" or "1_0" that
        # int() accepts but that _load_one could not map back to this file.
        if _get_name(prior, seed) == path.stem:
            seeds.append(seed)

    # iterdir() order is arbitrary; sorting makes ``start_at`` deterministic.
    yield from sorted(seeds)


def _load_one(prior: gp.Prior, seed: int) -> ndarray:
    """Load the cached ground truth of ``prior`` stored under ``seed``."""
    return np.load(DIRECTORY / f"{_get_name(prior, seed)}.npy")


def _gen_one(prior: gp.Prior) -> tuple[ndarray, int]:
    """Sample a new ground truth using the next unused seed.

    The seed is ``BASE_SEED`` when nothing is cached for ``prior``, otherwise
    one more than the largest cached seed.

    Returns:
        The sampled array and the seed that produced it.
    """
    seed = max(_get_available_seeds(prior), default=BASE_SEED - 1) + 1
    numpy_rng = default_rng(seed)
    return gp.sample_prior(numpy_rng, prior), seed


def _save(prior: gp.Prior, gt: ndarray, seed: int) -> None:
    """Write ``gt`` to ``DIRECTORY``, creating the directory if needed."""
    DIRECTORY.mkdir(exist_ok=True, parents=True)
    output_path = DIRECTORY / f"{_get_name(prior, seed)}.npy"
    np.save(output_path, gt)


def _get_name(prior: gp.Prior, seed: int) -> str:
    """Return the file stem used to cache ``prior``'s sample for ``seed``."""
    return f"{prior.name}_seed_{seed}"