# traffic-sign-classifier-robustness-testing/rt_search_based/examples/
#     fitness_function_visualization_example.py
import logging
import sys
from typing import cast

from rt_search_based.datasets.datasets import DatasetItem, MiniGtsrbDataset
from rt_search_based.fitness_functions.fitness_functions import (
    PenalizationFitnessFunction,
)
from rt_search_based.models.classifiers import Classifier2Test_1
from rt_search_based.transformations.stickers import Color


def run_example() -> None:
    """
    Run the fitness-function visualization example on a single dataset image.

    Builds a ``Classifier2Test_1`` classifier and takes the first item of a
    ``MiniGtsrbDataset``. For every fitness-function class listed in the loop
    below, it calls ``visualize_fitness_function_for_image`` with black and
    white sticker colors.
    """
    classifier2test = Classifier2Test_1()  # classifier under test

    # NOTE(review): the meaning of the constructor argument ``1`` (presumably
    # the number of images to load) is not visible here — confirm.
    dataset = MiniGtsrbDataset(1)

    # The image is the same for every fitness function, so fetch it once
    # rather than on each loop iteration. ``cast`` only narrows the static
    # type of ``dataset[0]``; it performs no runtime conversion.
    dataset_item = cast(DatasetItem, dataset[0])

    sticker_colors = [Color.BLACK, Color.WHITE]  # colors the sticker may use

    # Add further fitness-function classes to this tuple to visualize them too.
    for fitness_function_cls in (PenalizationFitnessFunction,):
        fitness_function = fitness_function_cls(classifier2test)
        fitness_function.visualize_fitness_function_for_image(
            dataset_item, sticker_colors
        )


if __name__ == "__main__":
    # Route INFO-and-above log records to both "info.log" and stdout so the
    # example's progress is visible live and also kept on disk.
    log_handlers = [
        logging.FileHandler("info.log"),
        logging.StreamHandler(sys.stdout),
    ]
    logging.basicConfig(level=logging.INFO, handlers=log_handlers)

    run_example()