"""Example script: evaluate the heuristic brute-force sticker strategies."""

import logging
import sys
from typing import Optional, Sequence

import torch

from rt_search_based.database.database import Database
from rt_search_based.datasets.datasets import MiniGtsrbDataset
from rt_search_based.fitness_functions.fitness_functions import (
    BasicAreaFitnessFunction,
)
from rt_search_based.models.classifiers import Classifier2Test_1
from rt_search_based.strategies.heuristic_brute_force_strategy import (
    HeuristicBruteForceStrategy,
    HeuristicBruteForceStrategyCenter,
    HeuristicBruteForceStrategyOptimal,
)
from rt_search_based.transformations.stickers import Color

# Strategy classes evaluated by default, in this order.
DEFAULT_STRATEGIES = (
    HeuristicBruteForceStrategyCenter,
    HeuristicBruteForceStrategy,
    # This takes by far the longest so be careful
    HeuristicBruteForceStrategyOptimal,
)


def run_example(
    num_images: int = 50,
    sticker_colors: Optional[Sequence[Color]] = None,
    strategies: Optional[Sequence[type]] = None,
) -> None:
    """Evaluate brute-force sticker strategies on a small MiniGtsrb subset.

    Called without arguments this reproduces the original example: the first
    50 images, black and white stickers, and all three heuristic brute-force
    strategies, evaluated in order.

    Args:
        num_images: Value passed to ``MiniGtsrbDataset`` (the first *n* images
            of the dataset).
        sticker_colors: Colors the strategies may use for stickers. Defaults
            to ``[Color.BLACK, Color.WHITE]``.
        strategies: Strategy classes to evaluate, in order. Each one is
            constructed as ``cls(fitness_function, classifier, sticker_colors)``
            and then ``evaluate(dataset, database)`` is called on it. Defaults
            to ``DEFAULT_STRATEGIES``.
    """
    # Resolve defaults here rather than in the signature to avoid sharing a
    # mutable default list across calls.
    if sticker_colors is None:
        sticker_colors = [Color.BLACK, Color.WHITE]
    if strategies is None:
        strategies = DEFAULT_STRATEGIES

    classifier2test = Classifier2Test_1()  # classifier to be used
    dataset = MiniGtsrbDataset(num_images)  # first n images from dataset
    database = Database()

    # Move the underlying model to the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    classifier2test.classifier.to(device)

    for strategy_cls in strategies:
        # A fresh fitness function per strategy, as in the original example --
        # presumably so no state carries over between strategies (unverified).
        fitness_function = BasicAreaFitnessFunction(classifier2test)
        strategy = strategy_cls(fitness_function, classifier2test, sticker_colors)
        # NOTE(review): results are presumably recorded in `database`; confirm
        # against the strategies' `evaluate` implementation.
        strategy.evaluate(dataset, database)


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        handlers=[
            # Explicit encoding so the log file does not depend on the locale.
            logging.FileHandler("info.log", encoding="utf-8"),
            logging.StreamHandler(sys.stdout),
        ],
    )
    run_example()