# traffic-sign-classifier-robustness-testing/tests/test_rt_search_based.py
from unittest import TestCase

import pytest
from rt_search_based import __version__
from rt_search_based.datasets.datasets import MiniGtsrbDataset
from rt_search_based.fitness_functions.fitness_functions import BasicAreaFitnessFunction
from rt_search_based.models.classifiers import Classifier2Test_1
from rt_search_based.strategies.jmetalpy_strategies import JMetalPyGeneticStrategy
from rt_search_based.strategies.pymoo_strategies import PyMooGeneticStrategy
from rt_search_based.transformations import stickers, norms_and_colors


def test_version():
    """The package must report the expected release version string."""
    expected_version = "0.1.0"
    assert __version__ == expected_version


# Strategy class tests


def test_jmetalpy_genetic_strategy_repr():
    """repr() of a JMetalPyGeneticStrategy must eval() back to an equal object."""
    classifier = Classifier2Test_1()
    strategy = JMetalPyGeneticStrategy(
        BasicAreaFitnessFunction(classifier),
        classifier,
        [stickers.Color.BLACK, stickers.Color.WHITE],
        10,
    )

    # eval() without explicit namespaces resolves names against this module's
    # globals, so every name the repr refers to must be importable here.
    rebuilt = eval(repr(strategy))
    assert strategy == rebuilt


def test_pymoo_genetic_strategy_repr():
    """repr() of a PyMooGeneticStrategy must eval() back to an equal object."""
    classifier = Classifier2Test_1()
    fitness_function = BasicAreaFitnessFunction(classifier)
    colors = [stickers.Color.BLACK, stickers.Color.WHITE]
    strategy = PyMooGeneticStrategy(fitness_function, classifier, colors)

    # Round-trip through the textual representation; names in the repr are
    # looked up in this module's namespace by eval().
    rebuilt = eval(repr(strategy))
    assert strategy == rebuilt


def test_center_translator():
    """center_translator must map (16, 16, 16, 16) to (16, 16, 48, 48)."""
    translated = norms_and_colors.center_translator((16, 16, 16, 16))
    expected = (16, 16, 48, 48)
    assert translated == expected


def test_dataset():
    """Check MiniGtsrbDataset shuffling, repr, and size validation.

    - ``randomize`` must reorder the samples without adding, dropping, or
      duplicating any of them.
    - ``repr`` must have the form ``"<ClassName>(<length>)"``.
    - Constructing the dataset with a size that is too large (99999) must
      raise ``ValueError``.
    """
    dataset = MiniGtsrbDataset(10)

    # Snapshot the samples *before* shuffling. A bare reference
    # (``list_ = dataset.dataset``) would make the checks below vacuous if
    # ``randomize`` shuffles the underlying list in place: both names would
    # point at the same object and trivially compare equal. Copying is
    # harmless if ``randomize`` rebinds ``dataset.dataset`` instead.
    list_ = list(dataset.dataset)

    dataset.randomize(1)

    shuffled_list = dataset.dataset

    assert len(list_) == len(shuffled_list)

    # Same multiset of samples, order ignored.
    TestCase().assertCountEqual(list_, shuffled_list)

    assert repr(dataset) == f"{type(dataset).__name__}({dataset.length})"

    with pytest.raises(ValueError):
        # too big: requesting 99999 samples must be rejected
        MiniGtsrbDataset(99999)