# traffic-sign-classifier-robustness-testing / rt_search_based / strategies / jmetalpy_strategies.py
from abc import abstractmethod
from typing import List
import torch

from jmetal.algorithm.singleobjective.genetic_algorithm import GeneticAlgorithm
from jmetal.algorithm.singleobjective.local_search import LocalSearch
from jmetal.algorithm.singleobjective.simulated_annealing import SimulatedAnnealing
from jmetal.core.algorithm import Algorithm
from jmetal.core.problem import IntegerProblem
from jmetal.core.solution import IntegerSolution
from jmetal.operator import BinaryTournamentSelection, IntegerPolynomialMutation
from jmetal.operator.crossover import IntegerSBXCrossover
from jmetal.util.termination_criterion import StoppingByEvaluations

from rt_search_based.datasets.datasets import DatasetItem
from rt_search_based.fitness_functions.fitness_functions import FitnessFunction
from rt_search_based.models.classifiers import Classifier
from rt_search_based.strategies.strategies import Strategy
from rt_search_based.transformations import stickers
from rt_search_based.transformations.stickers import Color


class StickerProblem(IntegerProblem):
    """
    Single-objective integer problem: find the stickers that minimize the fitness function.

    Each sticker occupies ``VARIABLES_PER_STICKER`` consecutive decision
    variables. Their upper bounds are, in order: image width, image height,
    image width, image height and the number of available colors. The last
    variable of every group is used as an index into ``sticker_colors``; how
    the first four values are interpreted is decided by
    ``stickers.create_multi_sticker_properties``.
    """

    # Number of integer decision variables that encode a single sticker.
    VARIABLES_PER_STICKER = 5

    def __init__(
        self,
        dataset_item: DatasetItem,
        classifier: Classifier,
        fitness_function: FitnessFunction,
        sticker_colors: List[Color],
        sticker_count: int = 1,
    ):
        """
        Set up the decision variables and their bounds for one dataset item.

        Args:
            dataset_item: item whose ``image`` supplies the height and width
                used as variable bounds; also passed to the fitness function.
            classifier: stored on the problem; not used by the problem itself.
            fitness_function: callable invoked as
                ``fitness_function(sticker, dataset_item)``; its result is the
                objective value to minimize.
            sticker_colors: colors a sticker may take.
            sticker_count: number of stickers encoded in one solution.
        """
        super().__init__()

        self.dataset_item = dataset_item

        self.classifier = classifier
        self.fitness_function = fitness_function
        self.sticker_colors = sticker_colors
        self.sticker_count = sticker_count

        self.number_of_variables = self.VARIABLES_PER_STICKER * sticker_count
        self.number_of_objectives = 1
        self.number_of_constraints = 0

        self.obj_directions = [self.MINIMIZE]
        # ``obj_labels`` is the spelling used by jMetalPy's problem examples;
        # the misspelled ``obj_lables`` is kept as an alias so that any code
        # that read the old attribute keeps working.
        self.obj_labels = ["Coordinates"]
        self.obj_lables = self.obj_labels

        # The image size unpacks into three values; the last two are treated
        # as height and width.
        _, img_height, img_width = self.dataset_item.image.size()

        # NOTE(review): it is unclear whether jMetalPy treats the upper bound
        # as inclusive. If it does, the color variable can reach
        # len(sticker_colors); evaluate() wraps it back into a valid index.
        self.upper_bound = [
            img_width,
            img_height,
            img_width,
            img_height,
            len(self.sticker_colors),
        ] * self.sticker_count

        # The lower bound (0) is included in the search range.
        self.lower_bound = [0] * len(self.upper_bound)

        # NOTE(review): these are assigned on the IntegerSolution *class*, so
        # they are shared by every IntegerSolution in the process, not just
        # the ones created for this problem - confirm this is intended.
        IntegerSolution.upper_bound = self.upper_bound
        IntegerSolution.lower_bound = self.lower_bound

    def evaluate(self, solution: IntegerSolution) -> IntegerSolution:
        """
        Compute the fitness of ``solution`` and store it as its only objective.

        The color variable of every sticker is first reduced modulo the number
        of colors, in place, so the stored variables are always valid indices
        into ``sticker_colors``.

        Args:
            solution: candidate whose ``variables`` encode the stickers.

        Returns:
            The same ``solution`` object with ``objectives[0]`` set.
        """
        # jMetalPy sometimes produces a color index outside the upper
        # boundary, which would otherwise cause a list-index-out-of-range
        # error when the color is looked up.
        color_count = len(self.sticker_colors)
        color_offset = self.VARIABLES_PER_STICKER - 1
        for start in range(0, len(solution.variables), self.VARIABLES_PER_STICKER):
            solution.variables[start + color_offset] %= color_count

        sticker = stickers.create_multi_sticker_properties(
            solution.variables, self.sticker_colors
        )

        solution.objectives[0] = self.fitness_function(sticker, self.dataset_item)
        return solution

    def get_name(self) -> str:
        """Return the string representation of this problem as its name."""
        return str(self)


class JMetalPyStrategy(Strategy):
    """
    Base class for sticker-search strategies that delegate to a jMetalPy algorithm.

    Subclasses choose the algorithm by implementing ``get_new_algorithm``;
    ``search_for_sticker`` builds the problem, runs the algorithm and checks
    the result against the classifier.
    """

    def __init__(
        self,
        fitness_function: FitnessFunction,
        classifier: Classifier,
        sticker_colors: List[Color] = None,
        max_evaluations: int = 1000,
        sticker_count: int = 1,
        **kwargs,
    ) -> None:
        """
        Store the search budget and forward everything else to ``Strategy``.

        Args:
            fitness_function: forwarded to ``Strategy``.
            classifier: forwarded to ``Strategy``.
            sticker_colors: forwarded to ``Strategy``.
            max_evaluations: evaluation budget handed to the algorithm's
                termination criterion by the subclasses.
            sticker_count: forwarded to ``Strategy``.
            **kwargs: forwarded unchanged to ``Strategy``.

        Raises:
            TypeError: if ``JMetalPyStrategy`` itself is instantiated instead
                of a subclass.
        """
        # Explicit guard against instantiating the base class directly. The
        # identity check only rejects this exact class, not an unrelated class
        # that merely shares its name.
        if type(self) is JMetalPyStrategy:
            raise TypeError(
                "JMetalPyStrategy is abstract; instantiate a subclass that "
                "implements get_new_algorithm()"
            )

        super().__init__(
            fitness_function,
            classifier,
            sticker_colors,
            sticker_count=sticker_count,
            **kwargs,
        )
        self.max_evaluations = max_evaluations

    @abstractmethod
    def get_new_algorithm(self, problem: StickerProblem) -> Algorithm:
        """returns the algorithm to be used by the jMetalPy strategy"""

    def search_for_sticker(self, dataset_item: DatasetItem) -> torch.Tensor:
        """
        Run the configured algorithm on ``dataset_item`` and return the sticker found.

        Args:
            dataset_item: item providing the ``image`` to attack and the
                ``label`` the prediction is compared against.

        Returns:
            The sticker properties built from the algorithm's result, or a
            tensor of the same shape filled with ``-1`` when the classifier
            still predicts ``dataset_item.label`` for the stickered image.
        """
        problem_ = StickerProblem(
            dataset_item,
            self.classifier,
            self.fitness_function,
            self.sticker_colors,
            sticker_count=self.sticker_count,
        )

        # A fresh algorithm instance is requested for every search.
        algorithm_ = self.get_new_algorithm(problem_)
        algorithm_.run()
        result = algorithm_.get_result()

        solution = stickers.create_multi_sticker_properties(
            result.variables, self.sticker_colors
        )

        # Check whether the solution fools the model.
        img = stickers.add_multi_sticker_to_image(solution, dataset_item.image)
        # Prepend a dimension of size one (presumably the batch axis the
        # classifier expects - TODO confirm).
        img = img[None, :]
        if self.classifier.get_predicted_class(img) == dataset_item.label:
            # Prediction unchanged: signal failure with an all -1 tensor.
            return torch.zeros_like(solution) - 1

        return solution


class JMetalPyGeneticStrategy(JMetalPyStrategy):
    """Sticker search driven by jMetalPy's single-objective ``GeneticAlgorithm``."""

    def get_new_algorithm(self, problem: StickerProblem) -> Algorithm:
        """
        Build a genetic algorithm for ``problem``.

        The run stops after ``self.max_evaluations`` evaluations.
        """
        # The first mutation argument scales inversely with the number of
        # decision variables.
        mutation_operator = IntegerPolynomialMutation(
            1.0 / problem.number_of_variables, 20.0
        )
        crossover_operator = IntegerSBXCrossover(0.9, 5.0)
        selection_operator = BinaryTournamentSelection()
        stop_condition = StoppingByEvaluations(max_evaluations=self.max_evaluations)

        return GeneticAlgorithm(
            problem=problem,
            population_size=200,
            offspring_population_size=2,
            mutation=mutation_operator,
            crossover=crossover_operator,
            selection=selection_operator,
            termination_criterion=stop_condition,
        )


class JMetalPyLocalSearchStrategy(JMetalPyStrategy):
    """Sticker search driven by jMetalPy's single-objective ``LocalSearch``."""

    def get_new_algorithm(self, problem: StickerProblem) -> Algorithm:
        """
        Build a local search for ``problem``.

        The run stops after ``self.max_evaluations`` evaluations.
        """
        # The first mutation argument scales inversely with the number of
        # decision variables.
        mutation_operator = IntegerPolynomialMutation(
            2.0 / problem.number_of_variables, 50.0
        )
        stop_condition = StoppingByEvaluations(max_evaluations=self.max_evaluations)

        return LocalSearch(
            problem=problem,
            mutation=mutation_operator,
            termination_criterion=stop_condition,
        )


class JMetalPySimulatedAnnealingSearchStrategy(JMetalPyStrategy):
    """Sticker search driven by jMetalPy's single-objective ``SimulatedAnnealing``."""

    def get_new_algorithm(self, problem: StickerProblem) -> Algorithm:
        """
        Build a simulated-annealing search for ``problem``.

        The run stops after ``self.max_evaluations`` evaluations.
        """
        # The first mutation argument scales inversely with the number of
        # decision variables.
        mutation_operator = IntegerPolynomialMutation(
            2.0 / problem.number_of_variables, 50.0
        )
        stop_condition = StoppingByEvaluations(max_evaluations=self.max_evaluations)

        return SimulatedAnnealing(
            problem=problem,
            mutation=mutation_operator,
            termination_criterion=stop_condition,
        )