# traffic-sign-classifier-robustness-testing/tests/test_strategies.py
from unittest.mock import MagicMock, Mock, patch
import pytest
import torch

from tests.test_database_and_csv_manager import database_factory
from rt_search_based.database.database import Database
from rt_search_based.datasets.datasets import MiniGtsrbDataset
from rt_search_based.models.classifiers import Classifier2Test_1
from rt_search_based.fitness_functions.fitness_functions import (
    PenalizationFitnessFunction,
)
from rt_search_based.strategies.heuristic_brute_force_strategy import (
    HeuristicBruteForceStrategy,
    HeuristicBruteForceStrategyCenter,
    HeuristicBruteForceStrategyOptimal,
)

from rt_search_based.strategies.pymoo_strategies import PyMooGeneticStrategy
from rt_search_based.strategies.strategies import Strategy
from rt_search_based.strategies.jmetalpy_strategies import (
    JMetalPyGeneticStrategy,
    JMetalPyLocalSearchStrategy,
    JMetalPySimulatedAnnealingSearchStrategy,
    JMetalPyStrategy,
)


@pytest.fixture
def strategy_test_runner(database_factory):
    """
    Return a function that smoke-tests a Strategy class.

    The returned ``runner`` builds the strategy with a fixed classifier and
    fitness function and searches for a sticker on a single image. It checks
    that the returned box is well formed, then verifies that ``evaluate``
    calls ``search_for_sticker`` exactly once for that same image.
    """

    def runner(strategy_class: "type[Strategy]", need_constraints: bool = False):
        """
        Run the smoke test for ``strategy_class``.

        Running this can take quite some time as the strategies are rather
        complex constructs.

        :param strategy_class: Strategy class to instantiate and exercise.
        :param need_constraints: forwarded unchanged to the strategy constructor.
        """
        classifier = Classifier2Test_1()
        fitness_function = PenalizationFitnessFunction(classifier)
        dataset = MiniGtsrbDataset(13)
        # Use the 13th image (index 12): it has a small minimal solution,
        # which keeps execution in the pipeline fast.
        dataset_item = dataset[12]

        _, image_height, image_width = dataset_item.image.size()
        image_area = image_height * image_width

        # NOTE(review): only the first value of the database_factory generator
        # is consumed; any teardown after its yield is not driven here.
        database: Database = next(database_factory(dataset_length=len(dataset)))

        strategy = strategy_class(
            fitness_function,
            classifier,
            need_constraints=need_constraints,
        )

        # The heuristic brute-force strategies run on the GPU when one is
        # available and on the CPU otherwise. For every other strategy the
        # classifier is left wherever its constructor placed it.
        if strategy_class in (
            HeuristicBruteForceStrategyOptimal,
            HeuristicBruteForceStrategy,
        ):
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            strategy.classifier.classifier.to(device)

        solution = strategy.search_for_sticker(dataset_item)
        area, x1, y1, x2, y2, *_ = solution.tolist()

        # The sticker must be a well-formed box no larger than the image.
        assert x2 >= x1
        assert y2 >= y1
        assert area <= image_area

        # Wrap the real method so calls are recorded but still executed.
        strategy.search_for_sticker = MagicMock(side_effect=strategy.search_for_sticker)

        # Only evaluate the last image. It is expected to be the same item
        # searched above, so evaluate must forward exactly that item.
        strategy.evaluate(dataset[-1:], database)
        strategy.search_for_sticker.assert_called_once_with(dataset_item)

    return runner


def test_jmetalpystrategy_abstract():
    """JMetalPyStrategy is abstract, so instantiating it must raise TypeError."""
    dummy_args = (None, None)
    with pytest.raises(TypeError):
        JMetalPyStrategy(*dummy_args)


def test_pymoo_genetic_strategy(strategy_test_runner):
    """Smoke-test PyMooGeneticStrategy with constraints enabled.

    The tested fitness function does not actually need constraints. Turning
    them on exercises the constraint handling along the way.
    """
    strategy_test_runner(
        strategy_class=PyMooGeneticStrategy,
        need_constraints=True,
    )


def test_jmetalpy_genetic_strategy(strategy_test_runner):
    """Smoke-test JMetalPyGeneticStrategy without constraints."""
    strategy_test_runner(strategy_class=JMetalPyGeneticStrategy, need_constraints=False)


def test_jmetalpy_local_search_strategy(strategy_test_runner):
    """Smoke-test JMetalPyLocalSearchStrategy without constraints.

    This test needs its own name: reusing ``test_jmetalpy_genetic_strategy``
    would shadow the genetic-strategy test, so pytest would never collect it.
    """
    strategy_test_runner(JMetalPyLocalSearchStrategy)


def test_jmetalpy_simulated_annealing_strategy(strategy_test_runner):
    """Smoke-test JMetalPySimulatedAnnealingSearchStrategy without constraints."""
    strategy_test_runner(
        strategy_class=JMetalPySimulatedAnnealingSearchStrategy,
        need_constraints=False,
    )


def test_heuristic_brute_force_strategy_center(strategy_test_runner):
    """Smoke-test HeuristicBruteForceStrategyCenter as a CPU-only run.

    ``torch.cuda.is_available`` is patched to return False for the whole run.
    Any device selection that consults it therefore falls back to the CPU,
    even on a machine that has a GPU.
    """
    with patch("torch.cuda.is_available", return_value=False):
        strategy_test_runner(HeuristicBruteForceStrategyCenter)


def test_heuristic_brute_force_strategy(strategy_test_runner):
    """Smoke-test HeuristicBruteForceStrategy (GPU is used when available)."""
    strategy_test_runner(strategy_class=HeuristicBruteForceStrategy, need_constraints=False)


def test_heuristic_brute_force_strategy_optimal(strategy_test_runner):
    """Smoke-test HeuristicBruteForceStrategyOptimal (GPU is used when available)."""
    strategy_test_runner(
        strategy_class=HeuristicBruteForceStrategyOptimal,
        need_constraints=False,
    )