import logging
import time
from abc import ABC, abstractmethod
from typing import List, Optional

import torch

from rt_search_based.datasets.datasets import DatasetItem
from rt_search_based.fitness_functions.fitness_functions import FitnessFunction
from rt_search_based.models.classifiers import Classifier
from rt_search_based.transformations.stickers import Color


class Strategy(ABC):
    """
    A Strategy is the abstract concept of a search algorithm that can search an
    image for the minimal sticker(s) that fools a DL classifier. For this it can
    utilize a FitnessFunction that helps to guide the search.
    """

    def __init__(
        self,
        fitness_function: FitnessFunction,
        classifier: Classifier,
        sticker_colors: Optional[List[Color]] = None,
        sticker_count: int = 1,
        **kwargs,
    ) -> None:
        """
        Args:
            fitness_function: used to guide the search; it is also evaluated on
                the final sticker properties when a database record is built.
            classifier: the classifier the stickers should fool.
            sticker_colors: colors available to the search; defaults to
                ``[Color.BLACK]`` when omitted.
            sticker_count: number of stickers searched for; determines how many
                per-sticker column groups appear in ``self.fields``.
            **kwargs: extra subclass arguments; here they are only recorded
                (as reprs) so that ``__repr__`` can reproduce them.
        """
        self.fitness_function = fitness_function
        self.classifier = classifier
        self.sticker_colors = (
            [Color.BLACK] if sticker_colors is None else sticker_colors
        )
        self.sticker_count = sticker_count
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Used by repr function to recreate strategy; __eq__ and __hash__ are
        # also based on repr. sticker_count must be part of it, otherwise
        # strategies that differ only in their number of stickers would compare
        # equal and could not be recreated faithfully from their repr.
        self.init_args = {
            "fitness_function": repr(self.fitness_function),
            "classifier": repr(self.classifier),
            "sticker_colors": repr(self.sticker_colors),
            "sticker_count": repr(self.sticker_count),
        }
        self.init_args.update({key: repr(value) for key, value in kwargs.items()})

        # specify the column descriptors later used in the csv files in the database
        self.fields: List[str] = ["fitness", "runtime", "area"]
        # add properties columns (corners + rgb color) for each sticker
        for i in range(self.sticker_count):
            self.fields.extend(
                f"{name}_{i}" for name in ("x1", "y1", "x2", "y2", "r", "g", "b")
            )

    @abstractmethod
    def search_for_sticker(self, dataset_item: DatasetItem) -> torch.Tensor:
        """
        Finds optimal stickers for the given dataset_item and returns the
        properties of the minimal sticker.

        Element 0 of the returned tensor is compared against the stored area
        in ``evaluate``, and a value of -1 there is treated as "no solution".
        """

    def evaluate(self, dataset, database) -> None:
        """
        Perform the search strategy for each dataset_item in the dataset.

        Writes values to the database if they are better (smaller area) than
        the previously stored ones, and calls ``db.write_back()`` after each
        stored improvement.

        Args:
            dataset: iterable of dataset items (each must expose ``.index``).
            database: context manager yielding the results database.
        """
        with database as db:
            key = str(self.fitness_function), str(self)
            # add new csv_manager to database if not already present
            if key not in db.strategies.keys():
                db.new(self.fitness_function, self)

            for dataset_item in dataset:
                start = time.perf_counter_ns()
                sticker_properties = self.search_for_sticker(dataset_item)
                runtime = time.perf_counter_ns() - start

                current = db[dataset_item.index, self.fitness_function, self]
                new_area = sticker_properties[0]
                # NOTE(review): the record built by initialize_record stores the
                # area at index 2, yet index 3 is read here. This presumably
                # accounts for an extra leading column in what db[...] returns
                # (e.g. a row index) — confirm against the database class.
                stored_area = current[3]

                # only store if a solution was found and either no data is
                # present yet (== -1) or the old solution was worse
                if new_area != -1 and (stored_area == -1 or stored_area > new_area):
                    new_record = self.initialize_record(
                        dataset_item, runtime, sticker_properties
                    )
                    db[
                        dataset_item.index,
                        str(self.fitness_function),
                        str(self),
                    ] = new_record
                    db.write_back()
                else:
                    logging.info(
                        "old value was already better or equally good - "
                        "current[3]=%r sticker_properties[0]=%r",
                        stored_area,
                        new_area,
                    )

    def initialize_record(self, dataset_item, runtime, sticker_properties):
        """
        Creates the record to be saved in the database, with the structure
        [fitness_value, runtime, area, *[x1, y1, x2, y2, r, g, b]*sticker_count]

        Args:
            dataset_item: item the stickers were searched for.
            runtime: search duration in nanoseconds.
            sticker_properties: tensor returned by ``search_for_sticker``.

        Returns:
            A 1-D CPU tensor: fitness and runtime followed by sticker_properties.
        """
        return torch.cat(
            [
                torch.tensor(
                    [
                        self.fitness_function(sticker_properties, dataset_item),
                        runtime,
                    ]
                ),
                # Need to make sure all tensors are on the same device before concatenating
                sticker_properties.to(device="cpu"),
            ]
        )

    def visualize_convergence(self):
        """
        Displays visual of how the strategy converges (no-op in the base class).
        """

    def __repr__(self) -> str:
        args = ", ".join(f"{key}={value}" for key, value in self.init_args.items())
        return f"{type(self).__name__}({args})"

    def __eq__(self, __o: object) -> bool:
        return repr(self) == repr(__o)

    def __hash__(self) -> int:
        # Defining __eq__ alone makes instances unhashable; hash on the same
        # repr that __eq__ compares so equal strategies hash equally.
        return hash(repr(self))

    def __str__(self) -> str:
        class_name = self.__class__.__name__
        # remove Strategy from the name as this causes the gui to have long headlines
        return class_name.replace("Strategy", "") + f"-{self.sticker_count}-Stickers"