# traffic-sign-classifier-robustness-testing/rt_search_based/examples/pymoo_all_example.py
import logging
import sys

from rt_search_based.database.database import Database
from rt_search_based.datasets.datasets import MiniGtsrbDataset, GtsrbDataset
from rt_search_based.fitness_functions.fitness_functions import (
    AreaFitnessFunction,
    BasicAreaFitnessFunction,
    FooledAreaFitnessFunction,
    PenalizationFitnessFunction,
    PositionFitnessFunction,
    RatioFitnessFunction,
)
from rt_search_based.models.classifiers import Classifier2Test_1
from rt_search_based.strategies.pymoo_strategies import PyMooGeneticStrategy
from rt_search_based.transformations.stickers import Color


def run_example(max_sticker_count=1, fitness_function_classes=None, sticker_colors=None):
    """Show an example of how pymoo strategies can be evaluated.

    For every sticker count from 1 up to ``max_sticker_count`` and every class
    in ``fitness_function_classes``, a ``PyMooGeneticStrategy`` is built around
    a ``Classifier2Test_1`` instance and evaluated on ``GtsrbDataset`` with a
    ``Database`` instance passed to ``evaluate``.

    Args:
        max_sticker_count: Largest sticker count to evaluate; every count from
            1 up to and including this value is run. Defaults to 1
            (single-sticker runs only).
        fitness_function_classes: Iterable of fitness-function classes; each
            is instantiated with the classifier under test. Defaults to
            ``(PenalizationFitnessFunction,)``. Other available choices include
            ``AreaFitnessFunction``, ``BasicAreaFitnessFunction`` and
            ``FooledAreaFitnessFunction``.
        sticker_colors: Sticker colors the strategy may use. Defaults to
            ``[Color.BLACK, Color.WHITE]``.
    """
    # Resolve defaults inside the body so no mutable default is shared
    # between calls and project names are only looked up when needed.
    if fitness_function_classes is None:
        fitness_function_classes = (PenalizationFitnessFunction,)
    if sticker_colors is None:
        sticker_colors = [Color.BLACK, Color.WHITE]

    # These fitness functions require the constraint that the model fails
    # (is fooled), so the strategy has to be told to enforce it.
    constrained_fitness_functions = (
        AreaFitnessFunction,
        RatioFitnessFunction,
        PositionFitnessFunction,
    )

    classifier2test = Classifier2Test_1()
    dataset = GtsrbDataset()
    database = Database()

    for number_of_stickers in range(1, max_sticker_count + 1):
        for fitness_function_class in fitness_function_classes:
            logging.info(
                "Evaluating %s with %d sticker(s)",
                fitness_function_class.__name__,
                number_of_stickers,
            )
            fitness_function = fitness_function_class(classifier2test)
            need_constraints = fitness_function_class in constrained_fitness_functions

            pymoo_genetic_strategy = PyMooGeneticStrategy(
                fitness_function,
                classifier2test,
                sticker_colors,
                sticker_count=number_of_stickers,
                need_constraints=need_constraints,
            )

            pymoo_genetic_strategy.evaluate(dataset, database)


if __name__ == "__main__":
    # Send INFO-level records both to info.log and to the console, so the
    # run's output is kept on disk as well as shown live.
    log_handlers = [
        logging.FileHandler("info.log"),
        logging.StreamHandler(sys.stdout),
    ]
    logging.basicConfig(handlers=log_handlers, level=logging.INFO)

    run_example()