# traffic-sign-classifier-robustness-testing/tests/test_fitness_functions.py
from math import ceil
from typing import cast

import numpy as np
import numpy.linalg as LA
import pytest
import torch
from pytest import approx
from rt_search_based.datasets.datasets import DatasetItem, MiniGtsrbDataset
from rt_search_based.fitness_functions.fitness_functions import (
    AreaFitnessFunction,
    BasicAreaFitnessFunction,
    FitnessFunction,
    FooledAreaFitnessFunction,
    PenalizationFitnessFunction,
    PositionFitnessFunction,
    RatioFitnessFunction,
)
from rt_search_based.models.classifiers import Classifier2Test_1
from rt_search_based.transformations import stickers
from rt_search_based.transformations.stickers import Color, add_multi_sticker_to_image


class TestFitnessfunction:
    """Unit tests for the concrete fitness functions of ``rt_search_based``.

    Shared fixtures are built once, when the class body is executed: a test
    classifier, the first item of a one-element ``MiniGtsrbDataset`` and two
    hand-crafted stickers. Stickers are 8-element tensors; these tests read
    index 0 as the sticker area and indices 1-4 as ``x1, y1, x2, y2``. The
    remaining indices are not interpreted here.
    """

    classifier = Classifier2Test_1()
    dataset = MiniGtsrbDataset(1)
    dataset_item = dataset[0]

    # Sticker spanning (0, 0)-(63, 63) with area 64**2, i.e. the whole image
    # assuming 64x64 inputs. Used as the "classifier is fooled" scenario.
    sticker_fooled = torch.tensor([64**2, 0, 0, 63, 63, 0, 0, 0])

    # Degenerate sticker anchored at the top-left pixel with area 0. Used as
    # the "classifier is not fooled" scenario.
    sticker_not_fooled = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0])

    # Image tensor is unpacked as (<leading dim>, height, width); the leading
    # dimension is presumably channels — TODO confirm against the dataset.
    _, img_height, img_width = dataset_item.image.size()
    img_area = img_height * img_width

    def test_area_fitness_function(self):
        """AreaFitnessFunction returns the sticker area (index 0)."""
        fitness_function = AreaFitnessFunction(self.classifier)
        assert (
            fitness_function(self.sticker_fooled, self.dataset_item)
            == self.sticker_fooled[0].item()
        )

    def test_fooled_area_fitness_function(self):
        """FooledAreaFitnessFunction: area when fooled, else ``img_area + 100``."""
        fitness_function = FooledAreaFitnessFunction(self.classifier)

        # Scenario 1 - classifier is fooled: fitness is the sticker area.
        assert (
            fitness_function(self.sticker_fooled, self.dataset_item)
            == self.sticker_fooled[0].item()
        )

        # Scenario 2 - classifier is not fooled: fixed penalty above img_area.
        assert (
            fitness_function(self.sticker_not_fooled, self.dataset_item)
            == self.img_area + 100
        )

    def test_basic_area_fitness_function(self):
        """BasicAreaFitnessFunction: area when fooled, else ``img_area + area``."""
        fitness_function = BasicAreaFitnessFunction(self.classifier)

        # Scenario 1 - classifier is fooled: fitness is the sticker area.
        assert (
            fitness_function(self.sticker_fooled, self.dataset_item)
            == self.sticker_fooled[0].item()
        )

        # Scenario 2 - classifier is not fooled: area offset by img_area.
        assert (
            fitness_function(self.sticker_not_fooled, self.dataset_item)
            == self.img_area + self.sticker_not_fooled[0].item()
        )

    def test_penalization_fitness_function(self):
        """PenalizationFitnessFunction penalizes by the classifier's margin.

        When the classifier is not fooled, the expected fitness is
        ``img_area + int((p1 - p2) * 100)``. Here ``p1`` is the probability of
        the true label and ``p2`` is the highest probability among the other
        classes. If ``p1`` is approximately 1.0, the sticker area is also
        subtracted.
        """
        fitness_function = PenalizationFitnessFunction(self.classifier)

        # Scenario 1 - classifier is fooled: fitness is the sticker area.
        assert (
            fitness_function(self.sticker_fooled, self.dataset_item)
            == self.sticker_fooled[0].item()
        )

        # Scenario 2 - classifier is not fooled: recompute the probabilities
        # the fitness function sees for the stickered image.
        img = add_multi_sticker_to_image(
            self.sticker_not_fooled, self.dataset_item.image
        )
        # Add a leading batch dimension before querying the classifier.
        img = img.unsqueeze(0).to(self.sticker_not_fooled.device)

        prob_dict = self.classifier.get_class_probabilities(img)
        p1 = prob_dict.pop(self.dataset_item.label)
        p2 = max(prob_dict.values())

        expected = self.img_area + int((p1 - p2) * 100)
        if p1 == approx(1.0):
            expected -= self.sticker_not_fooled[0].item()

        actual = fitness_function(self.sticker_not_fooled, self.dataset_item)
        assert actual == expected

    def test_position_fitness_function(self):
        """PositionFitnessFunction = ``alpha * area + distance to image centre``.

        ``alpha`` is one more than the (truncated) distance from the image
        centre to the corner, so the area term always dominates the distance.
        """
        fitness_function = PositionFitnessFunction(self.classifier)

        image_center = np.array([ceil(self.img_height / 2), ceil(self.img_width / 2)])
        image_corner = np.zeros(2)

        # Work on plain Python ints rather than 0-d tensors so the arithmetic
        # below does not depend on torch -> numpy coercion.
        area, x1, y1, x2, y2, *_ = self.sticker_fooled.tolist()

        sticker_center = np.array([ceil((x2 - x1) / 2) + x1, ceil((y2 - y1) / 2) + y1])

        max_dist_from_center = int(LA.norm(image_center - image_corner))
        sticker_dist_from_center = int(LA.norm(sticker_center - image_center))

        alpha = max_dist_from_center + 1
        fitness_value = alpha * area + sticker_dist_from_center

        assert fitness_function(self.sticker_fooled, self.dataset_item) == fitness_value

    def test_ratio_fitness_function(self):
        """RatioFitnessFunction: area for square stickers, else ``area / |1 - ratio|``."""
        fitness_function = RatioFitnessFunction(self.classifier)

        area, x1, y1, x2, y2, *_ = self.sticker_fooled.tolist()
        sticker_width = x2 - x1
        sticker_height = y2 - y1

        ratio = sticker_width / sticker_height
        actual = fitness_function(self.sticker_fooled, self.dataset_item)

        if ratio == approx(1.0):
            assert actual == area
        else:
            assert actual == approx(area / abs(1 - ratio))


def test_fitnessfunction_abstract():
    """The abstract ``FitnessFunction`` base class must not be instantiable."""
    abstract_base = FitnessFunction
    with pytest.raises(TypeError):
        abstract_base(classifier=None)


def test_visualize_fitness_function():
    """Smoke test: visualizing the penalization fitness must not raise.

    Runs ``visualize_fitness_function_for_image`` on the first item of a
    one-element ``MiniGtsrbDataset`` with a single black sticker colour.
    ``n=1000`` is passed through unchanged; its meaning is defined by the
    visualization method (presumably a sample count — TODO confirm).
    """
    penalization = PenalizationFitnessFunction(Classifier2Test_1())

    mini_dataset = MiniGtsrbDataset(1)
    first_item = cast(DatasetItem, mini_dataset[0])
    sticker_colors = [Color.BLACK]

    penalization.visualize_fitness_function_for_image(
        first_item,
        sticker_colors,
        n=1000,
    )