# traffic-sign-classifier-robustness-testing / tests / test_database_and_csv_manager.py
import itertools as it
from pathlib import Path
from typing import Iterator
import pytest
import torch
from PIL import Image


from rt_search_based.database.csv_manager import CsvManager
from rt_search_based.database.database import Database


@pytest.fixture
def manager_factory(tmp_path):
    """Provide a factory that builds initialized ``CsvManager`` instances.

    Every manager produced by the factory is backed by the same ``test.csv``
    file inside the temporary directory pytest hands to the test.
    """
    assert tmp_path.exists()
    csv_file = tmp_path / "test.csv"

    def _manager_factory(field_length: int = 5, dataset_length: int = 5) -> CsvManager:
        # Field names are simply numbered: field0, field1, ...
        field_names = []
        for index in range(field_length):
            field_names.append(f"field{index}")

        csv_manager = CsvManager(dataset_length, csv_file, field_names)
        csv_manager.initialize_csv()
        return csv_manager

    return _manager_factory


@pytest.fixture
def database_factory(tmp_path):
    """Provide a generator factory that yields entered ``Database`` objects.

    Calling the returned function creates a generator.  Each ``next()`` on
    that generator enters a fresh ``Database`` rooted at ``tmp_path``;
    resuming the same generator first calls ``__exit__()`` on the database it
    yielded before.  A generator that is advanced only once never reaches
    that ``__exit__()`` call.
    """
    assert tmp_path.exists()
    tmp_path.mkdir(parents=True, exist_ok=True)

    def _database_factory(dataset_length: int = 5):
        while True:
            database = Database(dataset_length, tmp_path)
            opened = database.__enter__()
            yield opened

            # Reached only when the generator is resumed after the yield.
            opened.__exit__()

    return _database_factory


def assert_csv_manager_file_content(path, fields, values=it.cycle([-1])):
    """Assert that the CSV file at ``path`` has the expected header and values.

    The first line must equal ``fields`` joined by commas.  Every following
    line is split on commas and each cell, converted to ``int``, must equal
    the next item drawn from ``values``.  ``values`` is consumed continuously
    across lines, so a counting iterator checks the file as one flat sequence.
    The default is an endless stream of ``-1``.
    """
    with open(path, encoding="utf-8") as csv_file:
        remaining: Iterator[str] = iter(csv_file.readlines())

    header = next(remaining)
    assert fields == header.removesuffix("\n").split(",")

    for row in remaining:
        cells = row.removesuffix("\n").split(",")
        # The cells come first in zip() on purpose: zip stops as soon as the
        # cells run out, without drawing a surplus item from ``values``.
        for cell, expected in zip(cells, values):
            assert int(cell) == expected


def test_csv_manager_init(manager_factory):
    """A freshly initialized CSV file holds the header and only ``-1`` values."""
    manager = manager_factory(5, 5)
    assert_csv_manager_file_content(manager.path, manager.fields)


def test_csv_manager_delete(manager_factory):
    """Deleting a manager removes its backing CSV file from disk."""
    manager = manager_factory(1)
    # Remember the location first; it is checked after the manager is deleted.
    csv_file = manager.path

    manager.delete()

    assert not csv_file.exists()


def test_csv_manager_write_data(manager_factory):
    """In-memory assignments reach the CSV file only after ``write_csv()``."""
    field_length, dataset_length = 5, 5
    csv_manager: CsvManager = manager_factory(field_length, dataset_length)

    csv_manager.read_csv()

    # Let the values count up row by row.  The tensor is laid out like the
    # CSV file: one row per dataset entry, one column per field (the other
    # tests in this module assign one value per field to a dataset entry).
    csv_manager[:] = torch.arange(dataset_length * field_length).reshape(
        (dataset_length, field_length)
    )

    # Nothing has been written yet, so the file still holds only -1.
    assert_csv_manager_file_content(csv_manager.path, csv_manager.fields)

    csv_manager.write_csv()

    # After the write the stored values count up from 0.
    assert_csv_manager_file_content(
        csv_manager.path, csv_manager.fields, values=it.count(0)
    )


def test_database_new(database_factory):
    """``Database.new`` creates one initialized CSV file per registered pair."""
    field_length, dataset_length = 5, 5

    db: Database = next(database_factory(dataset_length))

    # Two independent (fitness function, strategy, fields) registrations; the
    # second one uses different names and a longer field list.
    scenarios = [
        (
            "FitnessFunction",
            "Strategy",
            [f"field{i}" for i in range(field_length)],
        ),
        (
            "AnotherFitnessFunction",
            "AnotherStrategy",
            [f"another_field{i}" for i in range(field_length + 7)],
        ),
    ]

    for fitness_function_key, strategy_key, expected_fields in scenarios:
        db.new(fitness_function_key, strategy_key, expected_fields)

        # The file is expected at <root>/<fitness function>/<strategy>.csv.
        csv_path = Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")

        assert csv_path.exists()

        # A new file carries the given header and only the default -1 values.
        assert_csv_manager_file_content(csv_path, expected_fields)


def test_database_delete(database_factory):
    """Deleting a registered pair removes its CSV file from disk."""
    dataset_length = 5
    fields = [f"field{i}" for i in range(5)]
    fitness_function_key, strategy_key = "FitnessFunction", "Strategy"

    db: Database = next(database_factory(dataset_length))
    db.new(fitness_function_key, strategy_key, fields)

    def current_csv_path() -> Path:
        # Recomputed on every call, so the check after the delete uses the
        # database's root folder as it is at that moment.
        return Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")

    # Before the delete: the file exists and the pair is listed.
    assert current_csv_path().exists()
    assert (fitness_function_key, strategy_key) in db[:].keys()

    db.delete(fitness_function_key, strategy_key)

    # After the delete: the file is gone.
    assert not current_csv_path().exists()


def test_database_purge(database_factory):
    """Purging the database removes the CSV file of a registered pair."""
    dataset_length = 5
    fields = [f"field{i}" for i in range(5)]
    fitness_function_key, strategy_key = "FitnessFunction", "Strategy"

    db: Database = next(database_factory(dataset_length))
    db.new(fitness_function_key, strategy_key, fields)

    def current_csv_path() -> Path:
        # Recomputed on every call, so the check after the purge uses the
        # database's root folder as it is at that moment.
        return Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")

    # Before the purge: the file exists and the pair is listed.
    assert current_csv_path().exists()
    assert (fitness_function_key, strategy_key) in db[:].keys()

    db.purge()

    # After the purge: the file is gone.
    assert not current_csv_path().exists()


def test_set_and_get_item(database_factory):
    """Assigned data is readable at once but stored only after ``write_back()``."""
    field_length, dataset_length = 5, 5
    fields = [f"field{i}" for i in range(field_length)]

    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"

    db.new(fitness_function_key, strategy_key, fields)

    # One row per dataset entry, one column per field, counting up from 0.
    data = torch.arange(dataset_length * field_length).reshape(
        (dataset_length, field_length)
    )

    db[:, fitness_function_key, strategy_key] = data

    csv_path = Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")

    # The assignment alone must not touch the file: it still holds only -1.
    assert_csv_manager_file_content(csv_path, fields)

    # Reading back through the database returns exactly what was assigned.
    assert torch.all(db[:, fitness_function_key, strategy_key] == data)

    db.write_back()

    # Now the file holds the counting values.
    assert_csv_manager_file_content(csv_path, fields, it.count(0))


def test_get_set_error(database_factory):
    """Indexing the database with an unsupported key raises ``ValueError``."""
    field_length, dataset_length = 5, 5
    fields = [f"field{i}" for i in range(field_length)]

    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"

    db.new(fitness_function_key, strategy_key, fields)

    # Writing with a key that is not a valid index must be rejected ...
    with pytest.raises(ValueError):
        db["nonsense"] = "non sense"

    # ... and so must reading with it.  The result is deliberately discarded.
    with pytest.raises(ValueError):
        _ = db["nonsense"]


def test_get_report(database_factory):
    """Compare the sticker-area statistics of ``get_report`` with torch's own."""
    field_length, dataset_length = 5, 5
    fields = [f"field{i}" for i in range(field_length)]

    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"

    db.new(fitness_function_key, strategy_key, fields)

    # One row per dataset entry, one column per field, counting up from 0.
    data = torch.arange(dataset_length * field_length).reshape(
        (dataset_length, field_length)
    )

    db[:, fitness_function_key, strategy_key] = data

    # Keep the first three columns only; per the original author only area,
    # runtime and fitness are used for statistics.
    data = data[:, :3]
    dict_ = db.get_report(fitness_function_key, strategy_key)
    # The reference statistics below are computed on double-precision values.
    data = data.double()

    # Reference values are reduced over dim 0, i.e. across dataset entries.
    assert all(dict_["Mean Sticker Area"] == torch.mean(data, dim=0))
    assert all(dict_["Sticker Area Variance"] == torch.var(data, dim=0))
    assert all(dict_["Median Sticker Area"] == torch.median(data, dim=0).values)
    assert all(dict_["Sticker Area Mode"] == torch.mode(data, dim=0).values)


def test_get_strategy_and_fitness_function_names(database_factory):
    """Saved strategy and fitness function names can be listed, also filtered.

    Strategy names and fitness function names are deliberately different, so
    a getter that returns the wrong kind of name makes the test fail.
    """
    field_length, dataset_length = 5, 5
    fields = [f"field{i}" for i in range(field_length)]

    db: Database = next(database_factory(dataset_length))

    fitness_keys = [f"FitnessFunction{i}" for i in range(3)]
    strategy_keys = [f"Strategy{i}" for i in range(3)]

    # Register every combination of fitness function and strategy.
    for fitness_key in fitness_keys:
        for strategy_key in strategy_keys:
            db.new(fitness_key, strategy_key, fields)

    # Filtered by strategy: every fitness function was saved for it.
    for strategy_key in strategy_keys:
        saved_keys = db.get_saved_fitness_function_names(strategy_key)
        assert all(key in saved_keys for key in fitness_keys)

    # Filtered by fitness function: every strategy was saved for it.
    for fitness_key in fitness_keys:
        saved_keys = db.get_saved_strategy_names(fitness_key)
        assert all(key in saved_keys for key in strategy_keys)

    # Unfiltered: each getter lists all names of its own kind.
    saved_strategy_names = db.get_saved_strategy_names()
    assert all(key in saved_strategy_names for key in strategy_keys)

    saved_fitness_function_names = db.get_saved_fitness_function_names()
    assert all(key in saved_fitness_function_names for key in fitness_keys)


def test_open_database_fit_existing_files(database_factory):
    """Opening a database on a folder with existing files keeps their content."""
    field_length, dataset_length = 5, 5
    fields = [f"field{i}" for i in range(field_length)]

    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"

    db.new(fitness_function_key, strategy_key, fields)

    # One row per dataset entry, one column per field, counting up from 0.
    data = torch.arange(dataset_length * field_length).reshape(
        (dataset_length, field_length)
    )

    db[:, fitness_function_key, strategy_key] = data

    db.write_back()

    # Open a second database on the same root folder.  Because a fresh
    # generator is created here, the first database is NOT closed: the
    # factory calls ``__exit__()`` only when the same generator is resumed.
    db = next(database_factory(dataset_length))

    csv_path = Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")
    # The previously written values are still in the file.
    assert_csv_manager_file_content(csv_path, fields, it.count(0))


def test_database_only_write_back_after_update(database_factory):
    """``write_back()`` leaves the file alone when the update flag is not set."""
    field_length, dataset_length = 5, 5
    fields = [f"field{i}" for i in range(field_length)]

    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"

    db.new(fitness_function_key, strategy_key, fields)

    # One row per dataset entry, one column per field, counting up from 0.
    data = torch.arange(dataset_length * field_length).reshape(
        (dataset_length, field_length)
    )

    db[:, fitness_function_key, strategy_key] = data

    csv_path = Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")

    csv_manager = db.strategies[fitness_function_key, strategy_key]
    # Reset the update flag by hand, as if nothing had been assigned.
    db.strategy_updated[csv_manager] = False
    db.write_back()

    # The file was not rewritten: it still holds only the default -1 values.
    assert_csv_manager_file_content(csv_path, fields)


def test_get_image_path(database_factory):
    """``get_image_path`` yields an openable image file and regenerates it."""
    field_length, dataset_length = 10, 5
    fields = [f"field{i}" for i in range(field_length)]

    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"

    db.new(fitness_function_key, strategy_key, fields)

    # Assign the same ten field values to every dataset entry.  Per the
    # original author they describe a black sticker covering the whole image.
    db[:, fitness_function_key, strategy_key] = torch.tensor(
        [64**2, 1, 64**2, 0, 0, 64, 64, 42, 42, 42]
    )

    path = db.get_image_path(0, strategy_key, fitness_function_key)

    assert path.exists()
    # Image.open keeps the file open until the image is closed.  Close it
    # right away so no handle leaks and the unlink below also works on
    # platforms that refuse to delete open files.
    with Image.open(path):
        pass

    # Delete the file again to check that it is regenerated (needed because
    # the temp folder might not be fully cleared between test runs).
    path.unlink()

    path = db.get_image_path(0, strategy_key, fitness_function_key)

    assert path.exists()
    with Image.open(path):
        pass

    # A strategy without a fitness function is rejected.
    with pytest.raises(ValueError):
        db.get_image_path(0, strategy_key, None)

    # Calling with the index alone must not raise.
    # NOTE(review): the assertion below re-checks the sticker image path from
    # above, not the value returned by this call; confirm that is intended.
    db.get_image_path(0)
    assert path.exists()


def test_get_diagram_path(database_factory):
    """``get_diagram_image_path`` creates and regenerates diagram files."""
    field_length, dataset_length = 10, 5
    fields = [f"field{i}" for i in range(field_length)]

    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    # Two strategies share the fitness function; every dataset entry gets its
    # own, slightly different, set of field values.
    for i in range(2):
        strategy_key = f"Strategy{i}"
        db.new(fitness_function_key, strategy_key, fields)
        for j in range(dataset_length):
            db[j, fitness_function_key, strategy_key] = torch.tensor(
                [64**2 - j, 1 + j, 64**2 - j, j, j, 64 - j, 64 - j, 0, 0, 0]
            )

    path = db.get_diagram_image_path(fitness_function_key, "distribution")
    assert path.exists()
    # Delete the file again to check that it is regenerated (needed because
    # the temp folder might not be fully cleared between test runs).
    path.unlink()
    path = db.get_diagram_image_path(fitness_function_key, "distribution")
    assert path.exists()

    path = db.get_diagram_image_path(fitness_function_key, "box-plot")
    assert path.exists()
    # Image.open keeps the file open until the image is closed.  Close it
    # right away so no handle leaks and the unlink below also works on
    # platforms that refuse to delete open files.
    with Image.open(path):
        pass
    # Same regeneration check as for the distribution diagram.
    path.unlink()
    path = db.get_diagram_image_path(fitness_function_key, "box-plot")
    assert path.exists()
    with Image.open(path):
        pass

    # A missing fitness function is rejected.
    with pytest.raises(ValueError):
        db.get_diagram_image_path(None, "distribution")