# traffic-sign-classifier-robustness-testing / rt_search_based / examples / jmetalpy_all_example.py
import logging
import sys

from rt_search_based.database.database import Database
from rt_search_based.datasets.datasets import MiniGtsrbDataset
from rt_search_based.fitness_functions.fitness_functions import (
    FooledAreaFitnessFunction,
    PenalizationFitnessFunction,
)
from rt_search_based.models.classifiers import Classifier2Test_1
from rt_search_based.strategies.jmetalpy_strategies import (
    JMetalPyGeneticStrategy,
    JMetalPyLocalSearchStrategy,
)
from rt_search_based.transformations.stickers import Color


def run_example() -> None:
    """Evaluate every JMetalPy strategy / fitness function pairing.

    For sticker counts 1 and 2, each fitness function class is paired with
    each JMetalPy strategy class. Every resulting strategy is evaluated on
    the first 50 images of the mini GTSRB dataset, with one shared
    ``Database`` instance handed to each ``evaluate`` call.
    """
    model_under_test = Classifier2Test_1()
    images = MiniGtsrbDataset(50)
    results_db = Database()

    # Sticker colors the strategies are allowed to use.
    allowed_colors = [Color.BLACK, Color.WHITE]

    fitness_classes = (PenalizationFitnessFunction, FooledAreaFitnessFunction)
    strategy_classes = (JMetalPyGeneticStrategy, JMetalPyLocalSearchStrategy)

    # Flattened form of the three nested loops; the iteration order is
    # sticker count -> fitness function -> strategy.
    combinations = (
        (count, fitness_cls, strategy_cls)
        for count in range(1, 3)
        for fitness_cls in fitness_classes
        for strategy_cls in strategy_classes
    )

    for count, fitness_cls, strategy_cls in combinations:
        # A fresh fitness function instance is built for every strategy run.
        strategy = strategy_cls(
            fitness_cls(model_under_test),
            model_under_test,
            allowed_colors,
            sticker_count=count,
        )
        strategy.evaluate(images, results_db)


if __name__ == "__main__":
    # Send INFO-level (and above) records to both "info.log" and stdout,
    # then run the example.
    log_handlers = [
        logging.FileHandler("info.log"),
        logging.StreamHandler(sys.stdout),
    ]
    logging.basicConfig(handlers=log_handlers, level=logging.INFO)
    run_example()