# traffic-sign-classifier-robustness-testing/rt_search_based/examples/bruteforce_all_example.py
import logging
import sys

import torch
from rt_search_based.database.database import Database
from rt_search_based.datasets.datasets import MiniGtsrbDataset
from rt_search_based.fitness_functions.fitness_functions import BasicAreaFitnessFunction
from rt_search_based.models.classifiers import Classifier2Test_1
from rt_search_based.strategies.heuristic_brute_force_strategy import (
    HeuristicBruteForceStrategy,
    HeuristicBruteForceStrategyCenter,
    HeuristicBruteForceStrategyOptimal,
)
from rt_search_based.transformations.stickers import Color


def run_example():
    """Evaluate each heuristic brute-force strategy variant, one after another.

    One classifier, one dataset and one database are shared by all strategies;
    a new ``BasicAreaFitnessFunction`` is built for every strategy.
    """
    model = Classifier2Test_1()  # classifier under test
    images = MiniGtsrbDataset(50)  # presumably the first 50 images of the dataset
    results_db = Database()  # handed to every evaluate() call

    # Sticker colors the strategies are allowed to use.
    colors = [Color.BLACK, Color.WHITE]

    # Use the GPU when torch reports one, otherwise stay on the CPU.
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")
    model.classifier.to(target_device)

    strategy_classes = (
        HeuristicBruteForceStrategyCenter,
        HeuristicBruteForceStrategy,
        # By far the slowest of the three, so be careful.
        HeuristicBruteForceStrategyOptimal,
    )
    for strategy_cls in strategy_classes:
        strategy = strategy_cls(
            BasicAreaFitnessFunction(model),
            model,
            colors,
        )
        strategy.evaluate(images, results_db)


if __name__ == "__main__":
    # Send INFO-level (and above) records to both "info.log" and stdout.
    file_handler = logging.FileHandler("info.log")
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        handlers=[file_handler, stdout_handler],
        level=logging.INFO,
    )
    run_example()