import itertools
import logging
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Union

import torchvision as tv
from torch import Tensor


@dataclass
class DatasetItem:
    """A single sample: its position in the dataset, the image tensor, and its label."""

    index: int
    image: Tensor
    label: Any


class Dataset(ABC):
    """
    Base class for all datasets.

    It can return the full list of DatasetItem objects in the dataset,
    return a single DatasetItem given its index, shuffle the dataset,
    and report the number of items it contains.
    """

    _dataset: List[DatasetItem]

    @abstractmethod
    def __init__(self) -> None:
        """Initializes the dataset."""

    @property
    def dataset(self) -> List[DatasetItem]:
        """Returns the complete dataset as a list of DatasetItems."""
        return self._dataset

    def randomize(self, random_seed: Union[int, None] = None) -> None:
        """Shuffles the dataset in place, optionally with a fixed seed for reproducibility."""
        random.seed(random_seed)
        random.shuffle(self._dataset)

    def __getitem__(
        self, subscript: Union[int, slice]
    ) -> Union[DatasetItem, List[DatasetItem]]:
        return self._dataset[subscript]

    def __len__(self) -> int:
        return len(self._dataset)

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"


class GtsrbDataset(Dataset):
    """Test split of the GTSRB dataset; contains 12492 images."""

    def __init__(self) -> None:
        logging.info(
            f"Starting to initialize {type(self).__name__}, this could take a while"
        )
        gtsrb_dataset = tv.datasets.GTSRB(
            root="./rt_search_based/datasets/data/",
            split="test",
            download=True,
            transform=tv.transforms.Compose(
                [tv.transforms.Resize((64, 64)), tv.transforms.ToTensor()]
            ),
        )
        self._dataset = [
            DatasetItem(index, image, label)
            for index, (image, label) in enumerate(gtsrb_dataset)
        ]
        logging.info(f"Finished initializing {type(self).__name__}")


class MiniGtsrbDataset(Dataset):
    """
    First n images of the GTSRB test split, used during development
    to speed up testing.
    """

    # Size of the full GTSRB test split; requested lengths must stay below this.
    GTSRB_TEST_SIZE = 12492

    def __init__(self, length: int = 50) -> None:
        logging.info(
            f"Starting to initialize {type(self).__name__}, this could take a while"
        )
        self.length = length
        # Make sure the requested length does not exceed the original dataset.
        if self.length >= self.GTSRB_TEST_SIZE:
            raise ValueError("Requested length is larger than the original dataset")
        gtsrb_dataset = tv.datasets.GTSRB(
            root="./rt_search_based/datasets/data/",
            split="test",
            download=True,
            transform=tv.transforms.Compose(
                [tv.transforms.Resize((64, 64)), tv.transforms.ToTensor()]
            ),
        )
        # Only materialize the first `length` samples instead of the full split.
        first_n_in_gtsrb = itertools.islice(gtsrb_dataset, self.length)
        self._dataset = [
            DatasetItem(index, image, label)
            for index, (image, label) in enumerate(first_n_in_gtsrb)
        ]
        logging.info(f"Finished initializing {type(self).__name__}")

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.length})"
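

# A minimal usage sketch, not part of the library proper. It assumes network
# access for the first GTSRB download and a writable
# "./rt_search_based/datasets/data/" directory; the length of 5 and the seed
# of 42 below are arbitrary illustration values.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    dataset = MiniGtsrbDataset(length=5)
    print(repr(dataset))  # MiniGtsrbDataset(5)
    print(len(dataset))  # 5

    # Single-item and slice access both go through __getitem__.
    first = dataset[0]
    print(first.index, first.image.shape, first.label)  # image is a 3x64x64 tensor
    head = dataset[:2]  # slicing returns a list of DatasetItem
    print([item.index for item in head])

    # Shuffle reproducibly by passing a fixed seed.
    dataset.randomize(random_seed=42)
    print([item.index for item in dataset.dataset])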