# traffic-sign-classifier-robustness-testing/rt_search_based/utils/images.py
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch


def save_image(image: torch.Tensor, location: Path) -> None:
    """Save an image tensor to ``location`` as an image file.

    Parent directories of ``location`` are created when missing. Axis 0 of
    ``image`` is moved to position 2 before the data is handed to
    ``plt.imsave``, so ``image`` must have at least three dimensions.

    Args:
        image: Image data; presumably laid out as (channels, height, width)
            -- TODO confirm against callers. Tensors that require grad or
            live on a non-CPU device are detached and copied to host memory
            before conversion.
        location: Destination file path.
    """
    location.parent.mkdir(parents=True, exist_ok=True)

    # %r renders the same text the f-string "{location=}" would produce,
    # but the formatting is skipped when INFO logging is disabled.
    logging.info("img_utils: saving image location=%r", location)

    if isinstance(image, torch.Tensor):
        # np.array() cannot convert tensors that require grad or that live
        # on a non-CPU device, so detach and move to host memory explicitly.
        pixels = image.detach().cpu().numpy()
    else:
        # Keep accepting plain array-likes, as np.array() did.
        pixels = np.asarray(image)

    # Move axis 0 to the end to get the channel-last layout imsave expects.
    plt.imsave(location, np.moveaxis(pixels, 0, 2))