from typing import Callable, List, Optional, Sized

from torch import Tensor

import rt_search_based.search_algorithms.heuristic_brute_force as search
from rt_search_based.datasets.datasets import DatasetItem
from rt_search_based.fitness_functions.fitness_functions import FitnessFunction
from rt_search_based.models.classifiers import Classifier
from rt_search_based.strategies.strategies import Strategy
from rt_search_based.transformations import norms_and_colors
from rt_search_based.transformations.stickers import Color

# Upper limit used both when generating the search space and as the
# ``upper_bound`` of the search itself (the original code used 100 in both
# places). Values in the search space are stored as uint8, so this must stay
# below 256 or overflows can happen. The minimal stickers observed in practice
# are well below this limit.
_UPPER_BOUND = 100

# Batch size passed to ``search.gen_search_space`` for every strategy here.
_BATCH_SIZE = 1024


def _build_search_space(coordinate_range: int, norm: Callable, sticker_colors, **extra_kwargs):
    """Generate a brute-force search space covering all given sticker colors.

    Args:
        coordinate_range: number of values each coordinate can take
            (e.g. 32 gives coordinates in the range [0, 31]).
        norm: cost function over coordinates, handed to
            ``search.gen_search_space`` unchanged.
        sticker_colors: colors to include in the search space.
        **extra_kwargs: forwarded unchanged to ``search.gen_search_space``
            (e.g. ``translator``); omitted entirely when not given so the
            callee's own defaults apply.

    Returns:
        Whatever ``search.gen_search_space`` returns.
    """
    return search.gen_search_space(
        coordinate_range,
        _UPPER_BOUND,
        norm,
        sticker_colors,
        batch_size=_BATCH_SIZE,
        **extra_kwargs,
    )


def _area_weighted_center_norm(coordinates):
    """Return ``area_norm(coordinates) ** 1.5 * l2_norm(coordinates, (32, 32, 32, 32))``.

    NOTE(review): (32, 32, 32, 32) presumably is the midpoint of the 64-value
    coordinate grid used by ``HeuristicBruteForceStrategy`` — confirm against
    ``norms_and_colors.l2_norm``.
    """
    area = norms_and_colors.area_norm(coordinates)
    distance = norms_and_colors.l2_norm(coordinates, (32, 32, 32, 32))
    return area**1.5 * distance


def _search_single_sticker(
    strategy: Strategy,
    dataset_item: DatasetItem,
    sampler: Callable[[Sized], int],
    aggressive: bool,
) -> Tensor:
    """Run ``search.search_for_upper_bound`` over ``strategy.search_space``.

    Args:
        strategy: strategy providing ``search_space``, ``classifier`` and
            ``device``.
        dataset_item: image/label pair to attack; the image is moved to
            ``strategy.device`` first.
        sampler: decides how the batches are sampled; it receives a sized
            object and returns an int.
        aggressive: empirically, ``True`` often yields the same results as
            ``False`` but is faster (especially for attacks where the minimal
            sticker is rather large). Keep it off for an exhaustive search.

    Returns:
        The solution returned by ``search.search_for_upper_bound``.
    """
    return search.search_for_upper_bound(
        strategy.search_space,
        strategy.classifier,
        dataset_item.image.to(strategy.device),
        dataset_item.label,
        sampler,
        aggressive=aggressive,
        verbose=True,
        upper_bound=_UPPER_BOUND,
    )


class HeuristicBruteForceStrategyOptimal(Strategy):
    """
    Strategy that exhaustively searches for the optimal solution for a single
    sticker (might take a long time).
    """

    def __init__(
        self,
        fitness_function: FitnessFunction,
        classifier: Classifier,
        sticker_colors: Optional[List[Color]] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            fitness_function,
            classifier,
            sticker_colors,
            **kwargs,
        )
        # 64 values per coordinate, plain area norm, all configured colors.
        self.search_space = _build_search_space(
            64,
            norms_and_colors.area_norm,
            self.sticker_colors,
        )

    def search_for_sticker(self, dataset_item: DatasetItem) -> Tensor:
        """Search the full space for ``dataset_item`` without aggressive pruning."""
        # Sampler always returns 0; aggressive mode stays off so the search is exhaustive.
        return _search_single_sticker(self, dataset_item, lambda _: 0, aggressive=False)


class HeuristicBruteForceStrategy(Strategy):
    """
    Strategy that searches with the brute force approach in a non-exhaustive
    way to increase performance.
    """

    def __init__(
        self,
        fitness_function: FitnessFunction,
        classifier: Classifier,
        sticker_colors: Optional[List[Color]] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            fitness_function,
            classifier,
            sticker_colors,
            **kwargs,
        )
        # 64 values per coordinate, area**1.5 weighted by the L2 term, all colors.
        self.search_space = _build_search_space(
            64,
            _area_weighted_center_norm,
            self.sticker_colors,
        )

    def search_for_sticker(self, dataset_item: DatasetItem) -> Tensor:
        """Search ``dataset_item`` with aggressive mode on (faster, not exhaustive)."""
        # Sampler picks the index at 10% of the batch length.
        return _search_single_sticker(
            self,
            dataset_item,
            lambda x: int(len(x) * 0.1),
            aggressive=True,
        )


class HeuristicBruteForceStrategyCenter(Strategy):
    """
    Strategy that searches all single stickers that overlay the center of the
    image. This search space is only 1/16 of the full search space, as
    coordinates are only in the range of [0, 31] and are later transformed to
    be relative to the center by a translation function.
    """

    def __init__(
        self,
        fitness_function: FitnessFunction,
        classifier: Classifier,
        sticker_colors: Optional[List[Color]] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            fitness_function,
            classifier,
            sticker_colors,
            **kwargs,
        )
        # 32 values per coordinate; center_translator maps them relative to the center.
        self.search_space = _build_search_space(
            32,
            norms_and_colors.area_norm,
            self.sticker_colors,
            translator=norms_and_colors.center_translator,
        )

    def search_for_sticker(self, dataset_item: DatasetItem) -> Tensor:
        """Search the center-restricted space for ``dataset_item`` without aggressive pruning."""
        # Sampler picks the index at 50% of the batch length.
        return _search_single_sticker(
            self,
            dataset_item,
            lambda x: int(len(x) * 0.5),
            aggressive=False,
        )