"""Tests for the fitness functions in ``rt_search_based.fitness_functions``.

Every test evaluates a fitness function on the first item of
``MiniGtsrbDataset(1)``, using ``Classifier2Test_1`` as the model under test.
"""
from math import ceil
from typing import cast

import numpy as np
import numpy.linalg as LA
import pytest
import torch
from pytest import approx

from rt_search_based.datasets.datasets import DatasetItem, MiniGtsrbDataset
from rt_search_based.fitness_functions.fitness_functions import (
    AreaFitnessFunction,
    BasicAreaFitnessFunction,
    FitnessFunction,
    FooledAreaFitnessFunction,
    PenalizationFitnessFunction,
    PositionFitnessFunction,
    RatioFitnessFunction,
)
from rt_search_based.models.classifiers import Classifier2Test_1
from rt_search_based.transformations import stickers
from rt_search_based.transformations.stickers import Color, add_multi_sticker_to_image


class TestFitnessfunction:
    """Compare each concrete fitness function against hand-computed values."""

    # Shared test data. It is built once, when the class body executes
    # (i.e. at import / pytest collection time).
    classifier = Classifier2Test_1()
    dataset = MiniGtsrbDataset(1)
    dataset_item = dataset[0]

    # Sticker layout as read by these tests: index 0 is the sticker area and
    # indices 1-4 are (x1, y1, x2, y2). The last three entries are not
    # interpreted here. NOTE(review): presumably a colour; confirm in stickers.py.
    # Whole image covered (assumes a 64x64 image).
    sticker_fooled = torch.tensor([64**2, 0, 0, 63, 63, 0, 0, 0])
    # Only the top-left pixel covered.
    sticker_not_fooled = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0])

    # Image is unpacked as (channels, height, width).
    _, img_height, img_width = dataset_item.image.size()
    img_area = img_height * img_width

    def test_area_fitness_function(self):
        """AreaFitnessFunction yields the sticker's area entry."""
        area_fn = AreaFitnessFunction(self.classifier)

        actual = area_fn(self.sticker_fooled, self.dataset_item)
        assert actual == self.sticker_fooled[0].item()

    def test_fooled_area_fitness_function(self):
        """FooledAreaFitnessFunction: area when fooled, image area + 100 otherwise."""
        fooled_area_fn = FooledAreaFitnessFunction(self.classifier)

        # Scenario 1 - classifier is fooled: fitness is the sticker area.
        actual = fooled_area_fn(self.sticker_fooled, self.dataset_item)
        assert actual == self.sticker_fooled[0].item()

        # Scenario 2 - classifier is not fooled: a fixed offset of 100 is
        # added on top of the full image area.
        actual = fooled_area_fn(self.sticker_not_fooled, self.dataset_item)
        assert actual == self.img_area + 100

    def test_basic_area_fitness_function(self):
        """BasicAreaFitnessFunction: area when fooled, image area + area otherwise."""
        basic_area_fn = BasicAreaFitnessFunction(self.classifier)

        # Scenario 1 - classifier is fooled: fitness is the sticker area.
        actual = basic_area_fn(self.sticker_fooled, self.dataset_item)
        assert actual == self.sticker_fooled[0].item()

        # Scenario 2 - classifier is not fooled: the sticker area is added on
        # top of the full image area.
        actual = basic_area_fn(self.sticker_not_fooled, self.dataset_item)
        assert actual == self.img_area + self.sticker_not_fooled[0].item()

    def test_penalization_fitness_function(self):
        """PenalizationFitnessFunction penalises by the classifier's margin."""
        penalization_fn = PenalizationFitnessFunction(self.classifier)

        # Scenario 1 - classifier is fooled: fitness is the sticker area.
        actual = penalization_fn(self.sticker_fooled, self.dataset_item)
        assert actual == self.sticker_fooled[0].item()

        # Scenario 2 - classifier is not fooled. Re-run the classifier on the
        # stickered image to derive the expected penalty independently.
        stickered = add_multi_sticker_to_image(
            self.sticker_not_fooled, self.dataset_item.image
        )
        # Add a leading batch dimension and match the sticker's device.
        batch = stickered[None, :].to(self.sticker_not_fooled.device)
        probabilities = self.classifier.get_class_probabilities(batch)
        # After popping the true label, the remaining entries are the
        # competing classes.
        true_class_prob = probabilities.pop(self.dataset_item.label)
        best_other_prob = max(probabilities.values())
        margin_penalty = int((true_class_prob - best_other_prob) * 100)

        actual = penalization_fn(self.sticker_not_fooled, self.dataset_item)
        if true_class_prob == approx(1.0):
            expected = (
                self.img_area
                + margin_penalty
                - self.sticker_not_fooled[0].item()
            )
        else:
            expected = self.img_area + margin_penalty
        assert actual == expected

    def test_position_fitness_function(self):
        """PositionFitnessFunction = alpha * area + distance(sticker, centre)."""
        position_fn = PositionFitnessFunction(self.classifier)

        _, x1, y1, x2, y2, _, _, _ = self.sticker_fooled
        image_center = np.array(
            [ceil(self.img_height / 2), ceil(self.img_width / 2)]
        )
        sticker_center = np.array(
            [ceil((x2 - x1) / 2) + x1, ceil((y2 - y1) / 2) + y1]
        )

        # Distance from the image centre to the (0, 0) corner.
        max_dist_from_center = int(LA.norm(image_center - np.zeros(2)))
        sticker_dist_from_center = int(LA.norm(sticker_center - image_center))
        # NOTE(review): presumably alpha exceeds any centre distance so the
        # area term dominates the position term; confirm against the
        # implementation.
        alpha = max_dist_from_center + 1
        expected = alpha * self.sticker_fooled[0].item() + sticker_dist_from_center

        assert position_fn(self.sticker_fooled, self.dataset_item) == expected

    def test_ratio_fitness_function(self):
        """RatioFitnessFunction: area for square stickers, else area / |1 - ratio|."""
        ratio_fn = RatioFitnessFunction(self.classifier)

        area, x1, y1, x2, y2, *_ = self.sticker_fooled.tolist()
        ratio = (x2 - x1) / (y2 - y1)

        actual = ratio_fn(self.sticker_fooled, self.dataset_item)
        if ratio == approx(1.0):
            assert actual == area
        else:
            assert actual == approx(area / abs(1 - ratio))


def test_fitnessfunction_abstract():
    """Instantiating the FitnessFunction base class directly raises TypeError."""
    with pytest.raises(TypeError):
        FitnessFunction(classifier=None)


def test_visualize_fitness_function():
    """Smoke test: visualising PenalizationFitnessFunction runs to completion."""
    fitness_function = PenalizationFitnessFunction(Classifier2Test_1())
    dataset_item = cast(DatasetItem, MiniGtsrbDataset(1)[0])
    fitness_function.visualize_fitness_function_for_image(
        dataset_item, [Color.BLACK], n=1000
    )