# traffic-sign-classifier-robustness-testing / rt_search_based / fitness_functions / fitness_functions.py
from abc import ABC
from math import ceil
from typing import Optional, Union

import numpy as np
import numpy.linalg as LA
import rt_search_based.search_algorithms.heuristic_brute_force as search
import torch
from matplotlib import pyplot as plt
from pytest import approx
from rt_search_based.database.database import Database
from rt_search_based.datasets.datasets import DatasetItem
from rt_search_based.models.classifiers import Classifier
from rt_search_based.transformations import stickers
from tqdm import tqdm


class FitnessFunction(ABC):
    """
    Base class for fitness functions that score sticker configurations applied to an image.

    Subclasses implement ``__call__``. The base class itself cannot be instantiated.
    """

    def __init__(self, classifier: Classifier) -> None:
        # Only subclasses may be instantiated; the base class has no __call__ implementation.
        if self.__class__.__name__ == "FitnessFunction":
            raise TypeError
        self.classifier = classifier
        # Lazily generated (and then cached) by visualize_fitness_function_for_image.
        self.search_space: Optional[torch.Tensor] = None

    def __call__(
        self, sticker_props: torch.Tensor, dataset_item: DatasetItem
    ) -> Union[int, float]:
        """
        apply fitness function for the sticker_props to the dataset_item and return the fitness value

        Args:
            sticker_props (torch.Tensor): sticker properties like [area, *[x1, y1, x2, y2, r, g, b] * sticker_count]
            dataset_item (DatasetItem): DatasetItem that stores the image, truth_label and index within the dataset

        Returns:
            Union[int, float]: fitness value

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement __call__"
        )

    def visualize_fitness_function_for_image(
        self, dataset_item: DatasetItem, sticker_colors, n=10
    ):
        """
        Save a heatmap of how this fitness function scores stickers at each pixel of an image.

        A search space of candidate stickers is generated on the first call (and
        cached on ``self.search_space``). Every ``n``-th candidate is evaluated on
        ``dataset_item``, and its fitness value is added to every pixel the
        candidate's stickers cover. The per-pixel sum is then divided by how many
        evaluated stickers covered that pixel.

        The figure has three panels:

        - the heatmap;
        - the heatmap drawn over the image carrying the optimal sticker stored in
          the database for "HeuristicBruteForceOptimal-1-Stickers";
        - that image alone.

        It is written to ``./rt_search_based/imgs/<index>_<str(self)>.png``.

        Args:
            dataset_item (DatasetItem): image to evaluate.
            sticker_colors: colors passed to ``search.gen_search_space``. They are only
                used when the search space is generated, i.e. on the first call.
            n (int): evaluate only every n-th candidate of the search space.
        """
        # TODO This is still experimental so integrating this into gui etc is part of future work

        # Define search space in which all stickers are evaluated and later averaged.
        # NOTE(review): the search space is cached on the instance, so later calls
        # with different sticker_colors reuse the first space.
        if self.search_space is None:
            self.search_space = search.gen_search_space(
                64,
                20,
                lambda x: 0,
                # use all colors
                sticker_colors,
                batch_size=1,
            )
            self.search_space = self.search_space.to("cpu")

        _, image_height, image_width = dataset_item.image.size()

        # Accumulate in float64: fitness values may be floats, and an in-place
        # add of a float into an int64 tensor raises a casting error.
        sum_fitness_values = torch.zeros(
            (image_height, image_width), dtype=torch.float64
        )
        counter = torch.zeros((image_height, image_width), dtype=torch.int64)

        for position, sticker_props in enumerate(tqdm(self.search_space), start=1):
            # only sample every n-th value
            if position % n != 0:
                continue

            # remove an additional dimension (unpack batch of batch_size 1)
            sticker_props = torch.flatten(sticker_props)
            fitness_value = self(sticker_props, dataset_item)

            # add the fitness_value to each pixel the sticker covers (one pass per sticker)
            # NOTE(review): x1/x2 index the first (height) axis here; confirm this
            # matches the coordinate convention used by the stickers module.
            for i in range(1, len(sticker_props), 7):
                x1, y1, x2, y2, _, _, _ = sticker_props[i : i + 7]
                sum_fitness_values[x1 : x2 + 1, y1 : y2 + 1] += fitness_value
                counter[x1 : x2 + 1, y1 : y2 + 1] += 1

        # average the pixels; pixels never covered by a sampled sticker become NaN (0 / 0)
        avg_fitness_values = sum_fitness_values / counter

        fig, axis = plt.subplots(1, 3, figsize=(15, 5), dpi=300)
        try:
            # get optimal solution from database for the final plot
            with Database() as db:
                strat_name = "HeuristicBruteForceOptimal-1-Stickers"
                # extract the properties of the optimal sticker
                sticker_props = db[dataset_item.index, self, strat_name][2:]
                # generate image with sticker on it
                img = stickers.add_multi_sticker_to_image(
                    sticker_props, dataset_item.image
                )
                # rearrange dimensions (C, H, W) -> (H, W, C) for matplotlib
                img = img.permute(1, 2, 0)
                axis[1].imshow(img)
                axis[2].imshow(img)

            # heatmap
            axis[0].imshow(avg_fitness_values, cmap="viridis", interpolation="none")

            # heatmap with img below
            axis[1].imshow(
                avg_fitness_values, cmap="viridis", alpha=0.75, interpolation="none"
            )

            fig.savefig(f"./rt_search_based/imgs/{dataset_item.index}_{self}.png")
        finally:
            # pyplot keeps every figure alive until closed; release it even on error
            plt.close(fig)

    def __str__(self) -> str:
        return f"{type(self).__name__}"

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.classifier})"


class AreaFitnessFunction(FitnessFunction):
    """
    FitnessFunction whose value is the sticker area stored in ``sticker_props[0]``.

    Neither the image nor the classifier is consulted.
    """

    def __call__(
        self, sticker_props: torch.Tensor, dataset_item: DatasetItem
    ) -> Union[int, float]:
        # The first entry of sticker_props is the total sticker area.
        area_entry = sticker_props[0]
        return area_entry.item()


class FooledAreaFitnessFunction(FitnessFunction):
    """
    FitnessFunction that returns the sticker area if the stickered image fools the classifier.

    Otherwise it returns ``image_area + 100``, a value larger than any possible sticker area.
    """

    def __call__(
        self, sticker_props: torch.Tensor, dataset_item: DatasetItem
    ) -> Union[int, float]:

        area_of_stickers: int = int(sticker_props[0].item())

        _, height, width = dataset_item.image.size()
        not_fooled_penalty = height * width + 100  # larger than the largest sticker

        stickered_image = stickers.add_multi_sticker_to_image(
            sticker_props, dataset_item.image
        )

        # Batch of one for the classifier, placed on the device of sticker_props.
        batch = torch.unsqueeze(stickered_image, dim=0).to(sticker_props.device)

        predicted_label = self.classifier.get_predicted_class(batch)
        if dataset_item.label != predicted_label:
            return area_of_stickers
        return not_fooled_penalty


class BasicAreaFitnessFunction(FitnessFunction):
    """
    FitnessFunction that returns the sticker_area if the model is fooled and
    ``image_area + sticker_area`` if it is not.

    Offsetting non-fooling configurations by the full image area places all of them
    above all fooling configurations, as long as sticker areas do not exceed the image
    area. Within each group the sticker area still orders the configurations.
    """

    def __call__(
        self, sticker_props: torch.Tensor, dataset_item: DatasetItem
    ) -> Union[int, float]:

        sticker_area: int = int(sticker_props[0].item())

        _, image_height, image_width = dataset_item.image.size()
        image_area = image_height * image_width

        img = stickers.add_multi_sticker_to_image(sticker_props, dataset_item.image)

        # Add extra dimension used for classifier (equivalent to batch_size 1)
        img = torch.unsqueeze(img, dim=0)

        img = img.to(sticker_props.device)

        fools_classifier = dataset_item.label != self.classifier.get_predicted_class(
            img
        )

        if fools_classifier:
            return sticker_area
        return image_area + sticker_area


class PenalizationFitnessFunction(FitnessFunction):
    """
    FitnessFunction that returns sticker_area if the classifier is fooled.

    Otherwise it computes a fitness_value from two probabilities:

    - ``p1``: the predicted probability of the truth label;
    - ``p2``: the probability of the confusion label, i.e. the most probable other label.

    The value is then::

        image_area + int((p1 - p2) * 100)                  if p1 is not (approximately) 1.0
        image_area + int((p1 - p2) * 100) - sticker_area   if p1 is (approximately) 1.0
    """

    # Tolerance for "p1 is 1.0"; equals the default tolerance of pytest.approx(1.0)
    # (rel=1e-6 relative to the expected value 1.0).
    _CERTAINTY_TOLERANCE = 1e-6

    def __call__(
        self, sticker_props: torch.Tensor, dataset_item: DatasetItem
    ) -> Union[int, float]:

        sticker_area: int = int(sticker_props[0].item())

        _, image_height, image_width = dataset_item.image.size()
        image_area = image_height * image_width

        img = stickers.add_multi_sticker_to_image(sticker_props, dataset_item.image)

        # Add extra dimension used for classifier (equivalent to batch_size 1)
        img = torch.unsqueeze(img, dim=0)
        img = img.to(sticker_props.device)

        fools_classifier = dataset_item.label != self.classifier.get_predicted_class(
            img
        )

        if fools_classifier:
            # The probabilities are only needed for the non-fooled case.
            return sticker_area

        prob_dict = self.classifier.get_class_probabilities(img)

        # Read the probabilities without mutating prob_dict (pop() would modify the
        # mapping returned by the classifier in place).
        p1 = prob_dict[dataset_item.label]
        p2 = max(
            prob for key, prob in prob_dict.items() if key != dataset_item.label
        )

        fitness_value = image_area + int((p1 - p2) * 100)
        if abs(p1 - 1.0) <= self._CERTAINTY_TOLERANCE:
            # NOTE(review): presumably rewards larger stickers while the confidence
            # term is saturated at p1 == 1.0 — confirm intent.
            fitness_value -= sticker_area

        return fitness_value


class PositionFitnessFunction(FitnessFunction):
    """
    FitnessFunction that takes into account the sticker area and the position of the stickers.

        fitness_value = alpha * sticker_area + avg_sticker_dist_from_center

    ``alpha`` is one more than the distance from the image center to the corner (0, 0).
    For sticker centers inside the image, that distance is at least any sticker's
    distance from the center. A difference of one unit of area therefore always
    outweighs any difference in position.

    Between stickers of different sizes, the bigger one gets the higher fitness value.
    Between stickers of the same size, the one closer to the image center gets the
    lower fitness value.
    """

    def __call__(
        self, sticker_props: torch.Tensor, dataset_item: DatasetItem
    ) -> Union[int, float]:

        _, image_height, image_width = dataset_item.image.size()
        image_center = np.array([ceil(image_height / 2), ceil(image_width / 2)])
        image_corner = np.zeros(2)

        # The length of the sticker_props // 7 is equal to the sticker count
        sticker_count = len(sticker_props) // 7

        # Sum of the (truncated) distances between each sticker's center and the image center
        sticker_dist_from_center = 0
        for i in range(1, len(sticker_props), 7):
            x1, y1, x2, y2, *_ = sticker_props[i : i + 7]

            sticker_center = np.array(
                [ceil((x2 - x1) / 2) + x1, ceil((y2 - y1) / 2) + y1]
            )

            sticker_dist_from_center += int(LA.norm(sticker_center - image_center))

        # Raises ZeroDivisionError if sticker_props describes no sticker (length 1).
        avg_sticker_dist_from_center = ceil(sticker_dist_from_center / sticker_count)

        # The distance from the image center to the corner does not depend on the
        # sticker, so it is computed once outside the sticker loop.
        max_dist_from_center = int(LA.norm(image_center - image_corner))

        # The alpha has to remain greater than the max distance from the center
        # for the fitness value to be saturated by the sticker area
        alpha = max_dist_from_center + 1

        sticker_area = sticker_props[0].item()

        fitness_value = alpha * sticker_area + avg_sticker_dist_from_center

        return int(fitness_value)


class RatioFitnessFunction(FitnessFunction):
    """
    FitnessFunction that takes into account the sticker area and the ratio of the sticker dimensions.

    It favors stickers whose shape is closer to a square by assigning them lower values.

    For every sticker, the aspect ratio ``longer side / shorter side`` is computed. It is
    1.0 for a square and grows as the sticker gets more elongated, regardless of
    orientation (2:1 and 1:2 score the same). The fitness value is
    ``sticker_area * avg_aspect_ratio``: a square sticker scores exactly its area, and
    elongated stickers are penalized proportionally. The value is continuous in the
    sticker dimensions.

    NOTE(review): "favors" assumes the search minimizes the fitness value, as the
    "not fooled" penalties of the other fitness functions in this module imply — confirm.
    """

    def __call__(
        self, sticker_props: torch.Tensor, dataset_item: DatasetItem
    ) -> Union[int, float]:
        sticker_area = sticker_props[0].item()

        aspect_ratio_sum = 0.0
        for i in range(1, len(sticker_props), 7):
            x1, y1, x2, y2, *_ = sticker_props[i : i + 7]

            # float() turns the 0-d tensors into plain numbers so a Python float is returned
            sticker_width = float(x2 - x1 + 1)
            sticker_height = float(y2 - y1 + 1)
            aspect_ratio_sum += max(sticker_width, sticker_height) / min(
                sticker_width, sticker_height
            )

        # The length of the sticker_props // 7 is equal to the sticker count
        avg_aspect_ratio = aspect_ratio_sum / (len(sticker_props) // 7)

        return float(sticker_area * avg_aspect_ratio)