traffic-sign-classifier-robustness-testing / rt_search_based / database / database.py
database.py
Raw
import logging
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Dict,
    Generator,
    List,
    Literal,
    Optional,
    Set,
    Tuple,
    Union,
)

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from rt_search_based.database.csv_manager import CsvManager
from rt_search_based.datasets.datasets import Dataset, GtsrbDataset
from rt_search_based.transformations import stickers
from rt_search_based.utils import images as image_utils

# importing these types at runtime would create circular imports, which the interpreter reports as errors
# for further information see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/
# this is why the type annotations for Strategy and FitnessFunction are written as strings ("...")
if TYPE_CHECKING:
    from rt_search_based.fitness_functions.fitness_functions import FitnessFunction
    from rt_search_based.strategies.strategies import Strategy


class Database:
    """
    Database that manages multiple CsvManagers (each of which manages one csv file).
    Here you can register your algorithms to have their search results stored.
    We define that algorithm := (FitnessFunction, Strategy)

    On disk each algorithm is stored as ``<root_folder>/<fitness_function>/<strategy>.csv``.

    You can address the items in the following way:

        with Database() as db:

            # access data of one algorithm for all imgs:
            db[:, FitnessFunction, Strategy]

            # access data for img 2:
            db[2, FitnessFunction, Strategy]

            # access data for all strategies + fitness functions for all imgs:
            db[:]

            # data for all strategies + fitness functions for img 2,3,4
            db[2:5]

    Leaving the ``with`` block writes every modified table back to its csv file.
    """

    def __init__(
        self,
        count_imgs: int = 12630,
        root_folder: Path = Path("rt_search_based/database/csv/"),
    ):
        """Register a CsvManager for every csv file found below ``root_folder``.

        The csv files are not read here; reading happens in ``__enter__``.

        Args:
            count_imgs (int): number of images (rows) passed to every CsvManager
            root_folder (Path): directory holding one sub-directory per fitness function
        """
        # (fitness_function_name, strategy_name) -> manager of the matching csv file
        self.strategies: Dict[Tuple[str, str], CsvManager] = {}
        # dirty flags: only managers flagged True are written back to disk
        self.strategy_updated: Dict[CsvManager, bool] = {}
        self.count_imgs = count_imgs
        self.root_folder = root_folder
        # only load dataset if necessary for performance
        self.dataset: Optional[Dataset] = None

        for path in self.root_folder.glob("**/*.csv"):
            # layout: <root_folder>/<fitness_function>/<strategy>.csv
            key = (path.parent.name, path.stem)
            csv_manager = CsvManager(self.count_imgs, path)
            self.strategies[key] = csv_manager
            # nothing modified yet; also lets write_back() work without __enter__
            self.strategy_updated[csv_manager] = False

    def write_back(self) -> None:
        """
        force the CsvManagers to write their data tensor from memory back to the filesystem

        Only managers whose data was modified through ``__setitem__`` are written.
        """
        for csv_manager in self.strategies.values():
            # only write back if the data has changed (unknown managers count as unchanged)
            if self.strategy_updated.get(csv_manager, False):
                csv_manager.write_csv()
                # memory and file agree again
                self.strategy_updated[csv_manager] = False

    def __enter__(self) -> "Database":
        """read all the csv files into memory (torch tensors)

        Returns:
            Database: the database object to work with
        """
        for csv_manager in self.strategies.values():
            csv_manager.read_csv()
            # freshly read data matches the file, so it does not need to be written back
            self.strategy_updated[csv_manager] = False

        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        """write the information from memory (torch tensors) back to csv files

        The write-back also happens when the ``with`` body raised, so results computed
        so far are kept; the exception itself is not suppressed.
        """
        self.write_back()

    def new(
        self,
        fitness_function: "FitnessFunction",
        strategy: "Strategy",
        custom_fields: Optional[List[str]] = None,
    ) -> None:
        """register new strategy for which one stores metrics for each img in the dataset

        Args:
            fitness_function (FitnessFunction): fitness_function for which you want to store metrics
            strategy (Strategy): strategy for which you want to store metrics
            custom_fields (Optional[List[str]]): csv column names; defaults to ``strategy.fields``
        """
        fitness_function_name = str(fitness_function)
        strategy_name = str(strategy)

        # if the fitness_function wasn't used before we have to make a directory first
        directory = Path(self.root_folder, fitness_function_name)
        directory.mkdir(parents=True, exist_ok=True)

        key = (fitness_function_name, strategy_name)
        fields = custom_fields if custom_fields is not None else strategy.fields

        # drop the dirty flag of a manager this registration replaces
        previous = self.strategies.get(key)
        if previous is not None:
            self.strategy_updated.pop(previous, None)

        csv_manager = CsvManager(
            self.count_imgs,
            directory / f"{strategy_name}.csv",
            fields,
        )
        self.strategies[key] = csv_manager

        # initialize file with all -1 for "no information yet"
        # NOTE(review): presumably overwrites an existing file for the same key - confirm in CsvManager
        csv_manager.initialize_csv()
        # read file into memory
        csv_manager.read_csv()

        # freshly initialized data is already on disk
        self.strategy_updated[csv_manager] = False

    def delete(
        self,
        fitness_function: Union[str, "FitnessFunction"],
        strategy: Union[str, "Strategy"],
    ) -> None:
        """delete strategy with fitness_function from database (erases csv file)

        Args:
            strategy (Strategy): strategy to target
            fitness_function (FitnessFunction): fitness_function to target

        Raises:
            KeyError: if the algorithm is not registered
        """
        key = (str(fitness_function), str(strategy))
        csv_manager = self.strategies[key]
        csv_manager.delete()
        # only unregister once the file deletion succeeded
        del self.strategies[key]
        self.strategy_updated.pop(csv_manager, None)

    def purge(self) -> None:
        """delete all strategies from database (erases all csv files)"""

        for csv_manager in self.strategies.values():
            csv_manager.delete()
        self.strategies.clear()
        self.strategy_updated.clear()

    def __getitem__(
        self,
        index_tuple: Union[
            Union[int, slice],
            Tuple[
                Union[int, slice],
                Union["FitnessFunction", str],
                Union["Strategy", str],
            ],
        ],
    ):
        """Read stored rows.

        ``db[idx]`` returns a dict mapping every (fitness_function, strategy) key to
        ``data_tensor[idx]``; ``db[idx, fitness_function, strategy]`` returns the rows of one
        algorithm.

        Raises:
            ValueError: for any other index form
            KeyError: if the algorithm is not registered
        """
        if isinstance(index_tuple, (int, slice)):
            return {
                key: csv_manager.data_tensor[index_tuple]
                for key, csv_manager in self.strategies.items()
            }

        if isinstance(index_tuple, tuple) and len(index_tuple) == 3:
            rows, fitness_function, strategy = index_tuple
            key = (str(fitness_function), str(strategy))
            # same access path as the int/slice branch and __setitem__
            return self.strategies[key].data_tensor[rows]

        raise ValueError(
            "expected an int/slice or an (index, fitness_function, strategy) tuple"
        )

    def __setitem__(
        self,
        index_tuple: Tuple[int, Union["FitnessFunction", str], Union["Strategy", str]],
        data: torch.Tensor,
    ) -> None:
        """Write ``data`` into the rows of one algorithm and mark it for write-back.

        Raises:
            ValueError: if the index is not an (index, fitness_function, strategy) tuple
            KeyError: if the algorithm is not registered
        """
        if not (isinstance(index_tuple, tuple) and len(index_tuple) == 3):
            raise ValueError("expected an (index, fitness_function, strategy) tuple")

        rows, fitness_function, strategy = index_tuple
        csv_manager = self.strategies[(str(fitness_function), str(strategy))]
        csv_manager.data_tensor[rows] = data
        self.strategy_updated[csv_manager] = True

    def get_saved_fitness_function_names(
        self, strategy_name: Union[str, None] = None
    ) -> Set[str]:
        """
        returns set with names of all fitness functions with stored data
        if strategy_name is specified, only returns fitness functions used by given strategy
        """
        root = Path(self.root_folder)
        fitness_function_dirs = [dir_ for dir_ in root.glob("*/") if dir_.is_dir()]

        if strategy_name:
            # keep only dirs that contain a csv named after the given strategy
            fitness_function_dirs = [
                dir_
                for dir_ in fitness_function_dirs
                if any(file.stem == strategy_name for file in dir_.glob("*.csv"))
            ]

        return {dir_.name for dir_ in fitness_function_dirs}

    def get_saved_strategy_names(
        self, fitness_function_name: Union[str, None] = None
    ) -> Set[str]:
        """
        returns set with names of all strategies with stored data
        if fitness_function_name is specified, only returns strategy names of strategies that used given fitness function
        """
        parent_dir = Path(self.root_folder)

        if fitness_function_name:
            parent_dir = parent_dir / fitness_function_name

        return {file.stem for file in parent_dir.glob("**/*.csv")}

    def get_report(
        self,
        fitness_function: Union["FitnessFunction", str],
        strategy: Union["Strategy", str],
    ) -> Dict[str, torch.Tensor]:
        """
        get summary statistics of fitness_value, runtime and area associated with the
        key := (fitness_function, strategy)

        Every value is a 3-element tensor (one statistic per column: fitness score,
        runtime, sticker area), computed only over images that have stored results.
        NOTE(review): with no stored rows, median/mode receive an empty tensor - untested.
        """

        # retrieves fitness score, runtime (in ns), sticker area of all images for that specific strategy x fitness function
        targets = self[:, fitness_function, strategy][:, :3]

        # rows that were never written are all -1, so their sum is negative
        targets = targets[targets.sum(dim=1) >= 0].double()

        report: Dict[str, torch.Tensor] = {}

        report["Mean Sticker Area"] = torch.mean(targets, dim=0)
        report["Sticker Area Variance"] = torch.var(targets, dim=0)
        report["Sticker Area SD"] = torch.std(targets, dim=0)
        report["Median Sticker Area"] = torch.median(targets, dim=0).values
        report["Sticker Area Mode"] = torch.mode(targets, dim=0).values

        return report

    def get_image_path(
        self,
        index: int,
        strategy_name: Optional[str] = None,
        fitness_function_name: Optional[str] = None,
    ) -> Path:
        """
        returns the path of a stickered image of a strategy x fitness function or the unstickered image

        A missing image is rendered and saved into the cache first; failures while doing
        so are logged and the (possibly non-existent) path is returned anyway.

        Raises:
            ValueError: if only one of strategy_name / fitness_function_name is given
        """

        # image nomenclature:
        # {index}_{strategy_name}_{fitness_function_name}.png
        # {index}.png for unprocessed images

        PARENT_IMAGE_DIR = Path("./rt_search_based/imgs/png_image_cache")

        if strategy_name and fitness_function_name:
            image_name = f"{index}_{strategy_name}_{fitness_function_name}.png"
        elif strategy_name or fitness_function_name:
            raise ValueError(
                "requires both a strategy_name and a fitness_function_name"
            )
        else:
            image_name = f"{index}.png"

        image_path = PARENT_IMAGE_DIR / image_name

        if image_path.exists():
            return image_path

        try:
            # only load dataset if necessary for performance
            if self.dataset is None:
                self.dataset = GtsrbDataset()

            image = self.dataset[index].image  # type: ignore

            if strategy_name and fitness_function_name:
                sticker_props = self[index, fitness_function_name, strategy_name][2:]
                # cache the stickered image instead of the plain one
                image = stickers.add_multi_sticker_to_image(sticker_props, image)

            image_path.parent.mkdir(parents=True, exist_ok=True)
            image_utils.save_image(image, image_path)

        except ValueError as ve:
            logging.error("trying to save image in cache: %s", ve)

        except Exception as e:
            logging.exception(e)

        return image_path

    def _sticker_areas(self, fitness_function_name: str) -> List[Tuple[str, np.ndarray]]:
        """Return (strategy_name, areas) for every strategy stored with the given fitness
        function; areas is column 2 of the data tensor without uninitialized (-1) entries."""
        areas: List[Tuple[str, np.ndarray]] = []
        for (fitness_function, strategy), csv_manager in self.strategies.items():
            if fitness_function != fitness_function_name:
                continue
            targets = csv_manager.data_tensor[:, 2]
            areas.append((strategy, np.asarray(targets[targets != -1])))
        return areas

    def get_diagram_image_path(
        self,
        fitness_function_name: str,
        diagram_type: Literal["distribution", "box-plot"],
    ) -> Path:
        """
        returns the path of diagram showing the distribution of areas between strategy functions of a given fitness function

        The diagram is rendered and cached on first request.

        Raises:
            ValueError: if fitness_function_name is empty or diagram_type is unknown
        """
        logging.info("database: getting an image path")

        # image nomenclature:
        # {diagram_type}_for_{fitness_function_name}.png

        PARENT_IMAGE_DIR = Path("./rt_search_based/imgs/png_diagram_cache")

        if not fitness_function_name:
            raise ValueError("requires fitness_function_name")
        if diagram_type not in ("distribution", "box-plot"):
            raise ValueError(f"unknown diagram_type: {diagram_type!r}")

        image_name = f"{diagram_type}_for_{fitness_function_name}.png"

        image_path = PARENT_IMAGE_DIR / image_name

        logging.info(f"database: looking for img at: {image_path}")

        if not image_path.exists():
            logging.info("image does not exist, trying to create it")

            fig = plt.figure(figsize=(9, 6), dpi=500)
            try:
                areas = self._sticker_areas(fitness_function_name)

                if diagram_type == "distribution":
                    sns.set_style("white")
                    for strategy, targets in areas:
                        sns.distplot(
                            targets,
                            label=f"{strategy}",
                            kde_kws={"linewidth": 2},
                            hist=False,
                        )
                    plt.xlim(left=0)
                    plt.legend()
                    plt.xlabel("Sticker Area")
                else:
                    sns.boxplot(data=[targets for _, targets in areas], orient="h").set(
                        xlabel="Area"
                    )
                    plt.yticks(
                        np.arange(len(areas)), [strategy for strategy, _ in areas]
                    )

                # save plot
                plt.title(f"{fitness_function_name} Sticker Area {diagram_type}")
                plt.tight_layout()
                image_path.parent.mkdir(parents=True, exist_ok=True)
                plt.savefig(image_path)
            finally:
                # pyplot keeps figures alive until closed; release this one even on error
                plt.close(fig)

        logging.info(f"returning image path: {image_path=}")

        return image_path