"""Tests for ``CsvManager`` and ``Database``."""

import itertools as it
from pathlib import Path
from typing import Iterator

import pytest
import torch
from PIL import Image

from rt_search_based.database.csv_manager import CsvManager
from rt_search_based.database.database import Database


def _field_names(count: int, prefix: str = "field") -> list[str]:
    """Return ``count`` column names of the form ``<prefix><index>``."""
    return [f"{prefix}{index}" for index in range(count)]


def _counting_data(dataset_length: int, field_length: int) -> torch.Tensor:
    """Return a ``(dataset_length, field_length)`` tensor counting up from 0.

    Rows are dataset entries and columns are fields, so writing the tensor
    row by row yields 0, 1, 2, ... when the csv file is read front to back.
    """
    return torch.arange(dataset_length * field_length).reshape(
        (dataset_length, field_length)
    )


def _assert_image_file(path) -> None:
    """Assert that ``path`` exists and that PIL can open it."""
    assert path.exists()
    # Use the image as a context manager so the file handle is released.
    with Image.open(path):
        pass


@pytest.fixture
def manager_factory(tmp_path):
    """Provide a factory that builds an initialized ``CsvManager``.

    Every manager created by the factory uses ``<tmp_path>/test.csv``.
    """
    assert tmp_path.exists()
    csv_path = tmp_path.joinpath("test.csv")

    def _manager_factory(field_length: int = 5, dataset_length: int = 5) -> CsvManager:
        manager = CsvManager(dataset_length, csv_path, _field_names(field_length))
        manager.initialize_csv()
        return manager

    return _manager_factory


@pytest.fixture
def database_factory(tmp_path):
    """Provide a factory of generators that yield entered ``Database`` objects.

    ``next(database_factory(n))`` returns a database rooted at ``tmp_path``.
    Advancing the *same* generator again exits the database it yielded before
    and enters a new one at the same path. Every database that is still
    entered when the test finishes is exited during fixture teardown.
    """
    assert tmp_path.exists()
    tmp_path.mkdir(parents=True, exist_ok=True)

    entered: list[Database] = []

    def _leave(db: Database) -> None:
        # Drop by identity so this never depends on how Database compares.
        entered[:] = [other for other in entered if other is not db]
        db.__exit__(None, None, None)

    def _database_factory(dataset_length: int = 5):
        # No try/finally around the yield on purpose: most tests drop the
        # generator right after the first next(), and a finally clause would
        # then exit the database before the test had a chance to use it.
        while True:
            db = Database(dataset_length, tmp_path).__enter__()
            entered.append(db)
            yield db
            _leave(db)

    yield _database_factory

    # Teardown: exit whatever the test left open, most recent first.
    for db in reversed(list(entered)):
        _leave(db)


def assert_csv_manager_file_content(path, fields, values=None):
    """Assert that the csv file at ``path`` has the given header and values.

    Args:
        path: Location of the csv file.
        fields: Expected column names, compared against the header line.
        values: Expected integer cell values. An iterator is consumed
            continuously across all rows (e.g. ``it.count(0)``), whereas a
            re-iterable sequence starts over on every row. Defaults to an
            endless stream of ``-1``.

    Raises:
        AssertionError: If the file is empty, the header differs from
            ``fields`` or a stored value differs from the expected one.
    """
    if values is None:
        # Built per call: an iterator default would be shared between calls.
        values = it.cycle([-1])

    with open(path, encoding="utf-8") as handle:
        lines: Iterator[str] = iter(handle)
        header = next(lines, None)
        assert header is not None, f"{path} is empty"
        assert fields == header.removesuffix("\n").split(",")

        for line in lines:
            stored_values = line.removesuffix("\n").split(",")
            # stored_values must stay the first zip argument: zip stops as
            # soon as it is exhausted and does not pull a surplus item from
            # `values`, which keeps a counting iterator in step across rows.
            for stored_value, value in zip(stored_values, values):
                assert int(stored_value) == value


def test_csv_manager_init(manager_factory):
    """An initialized csv file has the header and -1 in every cell."""
    field_length, dataset_length = 5, 5
    csv_manager = manager_factory(field_length, dataset_length)
    assert_csv_manager_file_content(csv_manager.path, csv_manager.fields)


def test_csv_manager_delete(manager_factory):
    """``delete`` removes the csv file from disk."""
    csv_manager = manager_factory(1)
    csv_path = csv_manager.path
    csv_manager.delete()
    assert not csv_path.exists()


def test_csv_manager_write_data(manager_factory):
    """Assigned values reach the file only once ``write_csv`` is called."""
    field_length, dataset_length = 5, 5
    csv_manager: CsvManager = manager_factory(field_length, dataset_length)
    csv_manager.read_csv()

    # let the values count up
    csv_manager[:] = _counting_data(dataset_length, field_length)

    # the assignment alone must not touch the file: it is still all -1
    assert_csv_manager_file_content(csv_manager.path, csv_manager.fields)

    csv_manager.write_csv()

    # now the values in the file count up
    assert_csv_manager_file_content(
        csv_manager.path, csv_manager.fields, values=it.count(0)
    )


def test_database_new(database_factory):
    """``new`` creates ``<root>/<fitness function>/<strategy>.csv``."""
    field_length, dataset_length = 5, 5
    fields = _field_names(field_length)
    db: Database = next(database_factory(dataset_length))

    # first strategy / fitness function pair: the file must carry the fields
    # that were passed in
    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"
    db.new(fitness_function_key, strategy_key, fields)

    csv_path = Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")
    assert csv_path.exists()
    assert_csv_manager_file_content(csv_path, fields)

    # second, different strategy / fitness function pair with other fields
    other_fields = _field_names(field_length + 7, prefix="another_field")
    fitness_function_key = "AnotherFitnessFunction"
    strategy_key = "AnotherStrategy"
    db.new(fitness_function_key, strategy_key, other_fields)

    csv_path = Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")
    assert csv_path.exists()
    assert_csv_manager_file_content(csv_path, other_fields)


def test_database_delete(database_factory):
    """``delete`` removes the csv file of a single pair."""
    field_length, dataset_length = 5, 5
    fields = _field_names(field_length)
    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"
    db.new(fitness_function_key, strategy_key, fields)

    csv_path = Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")
    assert csv_path.exists()
    assert (fitness_function_key, strategy_key) in db[:].keys()

    db.delete(fitness_function_key, strategy_key)
    assert not csv_path.exists()


def test_database_purge(database_factory):
    """``purge`` removes the csv files of the database."""
    field_length, dataset_length = 5, 5
    fields = _field_names(field_length)
    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"
    db.new(fitness_function_key, strategy_key, fields)

    csv_path = Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")
    assert csv_path.exists()
    assert (fitness_function_key, strategy_key) in db[:].keys()

    db.purge()
    assert not csv_path.exists()


def test_set_and_get_item(database_factory):
    """Items can be read back at once but hit the disk only on write back."""
    field_length, dataset_length = 5, 5
    fields = _field_names(field_length)
    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"
    db.new(fitness_function_key, strategy_key, fields)

    data = _counting_data(dataset_length, field_length)
    db[:, fitness_function_key, strategy_key] = data
    csv_path = Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")

    # the file is not yet written ...
    assert_csv_manager_file_content(csv_path, fields)
    # ... but the data is already available through the database
    assert torch.all(db[:, fitness_function_key, strategy_key] == data)

    db.write_back()

    # the file is now written
    assert_csv_manager_file_content(csv_path, fields, it.count(0))


def test_get_set_error(database_factory):
    """Malformed keys raise ``ValueError`` on both get and set."""
    field_length, dataset_length = 5, 5
    fields = _field_names(field_length)
    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"
    db.new(fitness_function_key, strategy_key, fields)

    with pytest.raises(ValueError):
        db["nonsense"] = "non sense"

    with pytest.raises(ValueError):
        _ = db["nonsense"]


def test_get_report(database_factory):
    """``get_report`` returns mean, variance, median and mode of the data."""
    field_length, dataset_length = 5, 5
    fields = _field_names(field_length)
    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"
    db.new(fitness_function_key, strategy_key, fields)

    data = _counting_data(dataset_length, field_length)
    db[:, fitness_function_key, strategy_key] = data

    # only area, runtime, fitness are used for statistics
    data = data[:, :3].double()

    report = db.get_report(fitness_function_key, strategy_key)

    assert all(report["Mean Sticker Area"] == torch.mean(data, dim=0))
    assert all(report["Sticker Area Variance"] == torch.var(data, dim=0))
    assert all(report["Median Sticker Area"] == torch.median(data, dim=0).values)
    assert all(report["Sticker Area Mode"] == torch.mode(data, dim=0).values)


def test_get_strategy_and_fitness_function_names(database_factory):
    """Saved strategy and fitness function names can be listed."""
    field_length, dataset_length = 5, 5
    fields = _field_names(field_length)
    db: Database = next(database_factory(dataset_length))

    # The two key families must differ, otherwise a mix-up between strategy
    # names and fitness function names would go unnoticed.
    fitness_keys = [f"FitnessFunction{i}" for i in range(3)]
    strategy_keys = [f"Strategy{i}" for i in range(3)]

    for fitness_key in fitness_keys:
        for strategy_key in strategy_keys:
            db.new(fitness_key, strategy_key, fields)

    for strategy_key in strategy_keys:
        saved_keys = db.get_saved_fitness_function_names(strategy_key)
        assert all(key in saved_keys for key in fitness_keys)

    for fitness_key in fitness_keys:
        saved_keys = db.get_saved_strategy_names(fitness_key)
        assert all(key in saved_keys for key in strategy_keys)

    # without an argument every saved name is expected
    saved_strategy_names = db.get_saved_strategy_names()
    assert all(key in saved_strategy_names for key in strategy_keys)

    saved_fitness_function_names = db.get_saved_fitness_function_names()
    assert all(key in saved_fitness_function_names for key in fitness_keys)


def test_open_database_fit_existing_files(database_factory):
    """A database opened on an existing folder leaves written files intact."""
    field_length, dataset_length = 5, 5
    fields = _field_names(field_length)
    databases = database_factory(dataset_length)
    db: Database = next(databases)

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"
    db.new(fitness_function_key, strategy_key, fields)

    db[:, fitness_function_key, strategy_key] = _counting_data(
        dataset_length, field_length
    )
    db.write_back()

    # Advancing the same generator closes the old database and opens a new
    # one at the same path.
    db = next(databases)
    csv_path = Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")

    # the file is still written
    assert_csv_manager_file_content(csv_path, fields, it.count(0))


def test_database_only_write_back_after_update(database_factory):
    """``write_back`` skips strategies that are not flagged as updated."""
    field_length, dataset_length = 5, 5
    fields = _field_names(field_length)
    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"
    db.new(fitness_function_key, strategy_key, fields)

    db[:, fitness_function_key, strategy_key] = _counting_data(
        dataset_length, field_length
    )
    csv_path = Path(db.root_folder, fitness_function_key, f"{strategy_key}.csv")
    csv_manager = db.strategies[fitness_function_key, strategy_key]

    # simulate that the data wasn't changed
    db.strategy_updated[csv_manager] = False
    db.write_back()

    # the file was not written
    assert_csv_manager_file_content(csv_path, fields)


def test_get_image_path(database_factory):
    """``get_image_path`` returns an image file and regenerates it if missing."""
    field_length, dataset_length = 10, 5
    fields = _field_names(field_length)
    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"
    strategy_key = "Strategy"
    db.new(fitness_function_key, strategy_key, fields)

    # assign stickers that cover the whole image and are black
    db[:, fitness_function_key, strategy_key] = torch.tensor(
        [64**2, 1, 64**2, 0, 0, 64, 64, 42, 42, 42]
    )

    path = db.get_image_path(0, strategy_key, fitness_function_key)
    _assert_image_file(path)

    # delete the file again to check if it is regenerated (needed as in
    # different test runs the temp folder might not be fully cleared)
    path.unlink()
    path = db.get_image_path(0, strategy_key, fitness_function_key)
    _assert_image_file(path)

    with pytest.raises(ValueError):
        db.get_image_path(0, strategy_key, None)

    # NOTE(review): the return value is discarded, so the assertion below
    # re-checks the sticker image path from above - confirm whether the path
    # returned by this call was meant to be asserted instead.
    db.get_image_path(0)
    assert path.exists()


def test_get_diagram_path(database_factory):
    """``get_diagram_image_path`` returns diagram files and regenerates them."""
    field_length, dataset_length = 10, 5
    fields = _field_names(field_length)
    db: Database = next(database_factory(dataset_length))

    fitness_function_key = "FitnessFunction"

    for i in range(2):
        strategy_key = f"Strategy{i}"
        db.new(fitness_function_key, strategy_key, fields)
        for j in range(dataset_length):
            db[j, fitness_function_key, strategy_key] = torch.tensor(
                [64**2 - j, 1 + j, 64**2 - j, j, j, 64 - j, 64 - j, 0, 0, 0]
            )

    path = db.get_diagram_image_path(fitness_function_key, "distribution")
    assert path.exists()

    # delete the file again to check if it is regenerated (needed as in
    # different test runs the temp folder might not be fully cleared)
    path.unlink()
    path = db.get_diagram_image_path(fitness_function_key, "distribution")
    assert path.exists()

    path = db.get_diagram_image_path(fitness_function_key, "box-plot")
    _assert_image_file(path)

    # delete the file again to check if it is regenerated (needed as in
    # different test runs the temp folder might not be fully cleared)
    path.unlink()
    path = db.get_diagram_image_path(fitness_function_key, "box-plot")
    _assert_image_file(path)

    with pytest.raises(ValueError):
        db.get_diagram_image_path(None, "distribution")