# traffic-sign-classifier-robustness-testing / rt_search_based / strategies / strategies.py
import logging
import time
from abc import ABC, abstractmethod
from typing import List, Optional

import torch

from rt_search_based.datasets.datasets import DatasetItem
from rt_search_based.fitness_functions.fitness_functions import FitnessFunction
from rt_search_based.models.classifiers import Classifier
from rt_search_based.transformations.stickers import Color


class Strategy(ABC):
    """
    A Strategy is the abstract concept of a search algorithm that
    can search an image for the minimal sticker(s) that fool a DL classifier.
    For this it can utilize a FitnessFunction that helps to guide the search.

    Subclasses implement ``search_for_sticker``; ``evaluate`` runs that search
    over a dataset and stores improved results in a database.
    """

    def __init__(
        self,
        fitness_function: FitnessFunction,
        classifier: Classifier,
        sticker_colors: Optional[List[Color]] = None,
        sticker_count=1,
        **kwargs,
    ) -> None:
        """
        :param fitness_function: fitness function used to guide the search and
            to score the found stickers
        :param classifier: the classifier that should be fooled
        :param sticker_colors: colors the stickers may take, defaults to
            ``[Color.BLACK]``
        :param sticker_count: number of stickers searched for per image
        :param kwargs: strategy specific arguments; the base class only records
            their reprs in ``init_args`` so that ``repr`` can reproduce them
        """
        self.fitness_function = fitness_function
        self.classifier = classifier
        self.sticker_colors = (
            [Color.BLACK] if sticker_colors is None else sticker_colors
        )
        self.sticker_count = sticker_count
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Used by repr function to recreate strategy
        self.init_args = {
            "fitness_function": repr(self.fitness_function),
            "classifier": repr(self.classifier),
            "sticker_colors": repr(self.sticker_colors),
            # sticker_count is a named parameter and therefore never part of
            # kwargs; without it strategies that only differ in their sticker
            # count would share a repr (and thus compare equal).
            "sticker_count": repr(self.sticker_count),
        }
        self.init_args.update({key: repr(value) for key, value in kwargs.items()})

        # specify the column descriptors later used in the csv files in the database
        self.fields: List[str] = ["fitness", "runtime", "area"]
        # add properties column for each sticker
        for i in range(self.sticker_count):
            self.fields.extend(
                f"{prop}_{i}" for prop in ("x1", "y1", "x2", "y2", "r", "g", "b")
            )

    @abstractmethod
    def search_for_sticker(self, dataset_item: DatasetItem) -> torch.Tensor:
        """
        Find optimal stickers for the given dataset_item and return the
        properties of the minimal sticker(s).

        The returned tensor is laid out as
        [area, *[x1, y1, x2, y2, r, g, b]*sticker_count];
        ``evaluate`` treats an area of -1 as "no sticker found".
        """

    def evaluate(self, dataset, database):
        """
        Perform the search strategy for each dataset_item in the dataset.

        A result is written to the database only if a sticker was found and it
        is better (smaller area) than the previously stored one, or if nothing
        was stored yet. The database is written back after every stored
        improvement.

        :param dataset: iterable of dataset items (each exposing ``index``)
        :param database: context manager that yields the result database
        """
        with database as db:
            fitness_key, strategy_key = str(self.fitness_function), str(self)

            # add new csv_manager to database if not already present
            if (fitness_key, strategy_key) not in db.strategies.keys():
                db.new(self.fitness_function, self)

            for dataset_item in dataset:
                start = time.perf_counter_ns()
                sticker_properties = self.search_for_sticker(dataset_item)
                runtime = time.perf_counter_ns() - start

                current = db[dataset_item.index, self.fitness_function, self]

                # an area of -1 means the search did not find any sticker
                if sticker_properties[0] == -1:
                    logging.info(
                        "no sticker found for dataset item %s", dataset_item.index
                    )
                    continue

                # at index 3 the area is stored (-1 == no data present yet);
                # keep the stored value if it is at least as good
                if not (current[3] == -1 or current[3] > sticker_properties[0]):
                    logging.info(
                        f"old value was already better or equally good - {current[3]=} {sticker_properties[0]=}"
                    )
                    continue

                new_record = self.initialize_record(
                    dataset_item, runtime, sticker_properties
                )
                db[dataset_item.index, fitness_key, strategy_key] = new_record
                db.write_back()

    def initialize_record(self, dataset_item, runtime, sticker_properties):
        """
        Create the record to be saved in the database with the structure
        [fitness_value, runtime, area, *[x1, y1, x2, y2, r, g, b]*sticker_count]

        :param dataset_item: the item the stickers were searched for
        :param runtime: search runtime in nanoseconds
        :param sticker_properties: tensor [area, *[x1, y1, x2, y2, r, g, b]*sticker_count]
        :return: 1-D CPU tensor holding the record
        """
        return torch.cat(
            [
                torch.tensor(
                    [
                        self.fitness_function(sticker_properties, dataset_item),
                        runtime,
                    ]
                ),
                # Need to make sure all tensors are on the same device before concatenating
                sticker_properties.to(device="cpu"),
            ]
        )

    def visualize_convergence(self):
        """
        Displays visual of how the strategy converges.

        No-op in the base class; subclasses may override it.
        """

    def __repr__(self) -> str:
        args = ", ".join(f"{key}={value}" for key, value in self.init_args.items())
        return f"{type(self).__name__}({args})"

    def __eq__(self, __o: object) -> bool:
        # Strategies are equal iff their reprs (class + init arguments) match.
        if not isinstance(__o, Strategy):
            return NotImplemented
        return repr(self) == repr(__o)

    def __hash__(self) -> int:
        # Must agree with __eq__; defining __eq__ alone makes the class unhashable.
        # Note: mutating init_args after hashing changes the hash.
        return hash(repr(self))

    def __str__(self) -> str:
        class_name = self.__class__.__name__
        # remove Strategy from the name as this causes the gui to have long headlines
        return class_name.replace("Strategy", "") + f"-{self.sticker_count}-Stickers"