# traffic-sign-classifier-robustness-testing/tests/test_database_consistent.py
import re

import torch

from rt_search_based.database import database


# Extracts the sticker count N from strategy names of the form "...-N-Stickers".
_STICKER_COUNT_RE = re.compile(r"(\d+)-Stickers")


def _is_multi_sticker(strategy) -> bool:
    """Return True if ``str(strategy)`` names a strategy using more than one sticker."""
    match = _STICKER_COUNT_RE.search(str(strategy))
    return match is not None and int(match.group(1)) > 1


def _assert_not_better(data_tensor, optimal_tensor, column, description):
    """Assert that no entry of ``data_tensor[:, column]`` beats the optimum.

    Entries are compared row by row with ``optimal_tensor[:, column]``. A value
    is acceptable when it is greater than or equal to the optimum's value, or
    when it equals -1 (presumably the "no result" marker — TODO confirm against
    the code that fills the database).

    Args:
        data_tensor: Result tensor of the strategy under test.
        optimal_tensor: Result tensor of the brute-force optimum for the same
            fitness function; must have the same number of rows.
        column: Column index to compare.
        description: Human-readable label used in the failure message.
    """
    values = data_tensor[:, column]
    reference = optimal_tensor[:, column]
    violations = ~((values >= reference) | (values == -1))
    assert not torch.any(violations).item(), (
        f"{description} is better than the optimum at rows "
        f"{torch.nonzero(violations).flatten().tolist()}"
    )


def test_database_consistent_area_and_fitness():
    """Check that the optimal solution is better than or equal to every other strategy.

    For every (fitness function, strategy) entry in the database, each row's
    fitness value (column 0) and area (column 2) must be greater than or equal
    to the same column of the brute-force optimum computed with the same
    fitness function. Rows holding -1 in the compared column are ignored.
    Multi-sticker strategies are skipped because the brute force only supports
    a single sticker.
    """
    # this might have to be adapted if the addressing in the database changes
    optimal_strategy_key = "HeuristicBruteForceOptimal-1-Stickers"
    fitness_column = 0
    area_column = 2

    with database.Database() as db:
        for (fitness_key, strategy), data_tensor in db.strategies.items():
            # the brute force supports only a single sticker, so any strategy
            # placing more than one cannot be compared against it
            if _is_multi_sticker(strategy):
                continue

            # optimal results computed with the same fitness function
            optimal_tensor = db[:, fitness_key, optimal_strategy_key]

            # rows are compared pairwise, so both tensors must describe the
            # same samples; guard against silent broadcasting of a size-1 axis
            assert data_tensor.shape[0] == optimal_tensor.shape[0], (
                f"row count mismatch for {strategy} with {fitness_key}: "
                f"{data_tensor.shape[0]} vs {optimal_tensor.shape[0]} (optimal)"
            )

            context = f"{strategy} with {fitness_key}"
            # each column is checked against the SAME column of the optimum
            _assert_not_better(
                data_tensor, optimal_tensor, fitness_column, f"fitness of {context}"
            )
            _assert_not_better(
                data_tensor, optimal_tensor, area_column, f"area of {context}"
            )