# traffic-sign-classifier-robustness-testing/rt_search_based/datasets/datasets.py
import itertools
import logging
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Union

import torchvision as tv
from torch import Tensor


@dataclass
class DatasetItem:
    """A single dataset sample: its position in the dataset, its image, and its label."""

    # Position of the sample within the enclosing dataset (assigned via
    # enumerate() when the dataset is built; a later randomize() keeps the
    # original index with the item).
    index: int
    # Image tensor as produced by the dataset transform (here a Resize((64, 64))
    # followed by ToTensor -- so presumably a (C, 64, 64) float tensor; confirm
    # against the transform actually used).
    image: Tensor
    # Ground-truth label; typed Any -- GTSRB presumably yields an int class id,
    # but nothing in this file constrains it.
    label: Any


class Dataset(ABC):
    """
    Dataset class that serves as a base class for all datasets.
    It can return the full list of DatasetItem objects in the dataset,
    return a single DatasetItem object provided the index,
    randomize the dataset, and return the number of items in the dataset.
    """

    # Concrete subclasses must populate this list in __init__.
    _dataset: List[DatasetItem]

    @abstractmethod
    def __init__(self) -> None:
        """
        Initializes dataset; subclasses must fill ``self._dataset``.
        """

    @property
    def dataset(self) -> List[DatasetItem]:
        """
        Returns complete dataset as list of DatasetItems
        """
        return self._dataset

    def randomize(self, random_seed: Union[int, None] = None) -> None:
        """
        Shuffles the dataset in place.

        Uses a dedicated ``random.Random`` instance instead of reseeding the
        module-global generator (``random.seed``), so calling this no longer
        perturbs global random state elsewhere in the program. For a fixed
        ``random_seed`` the shuffle order is identical to the previous
        implementation, since both drive the same Mersenne Twister.
        """
        random.Random(random_seed).shuffle(self._dataset)

    def __getitem__(self, subscript) -> Union[DatasetItem, List[DatasetItem]]:
        # Delegates directly, so both integer indexing and slicing work.
        return self._dataset[subscript]

    def __len__(self) -> int:
        return len(self._dataset)

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"


class GtsrbDataset(Dataset):
    """
    Full GTSRB test split (12492 images), wrapped as DatasetItems.
    """

    def __init__(self) -> None:
        logging.info(
            f"Starting to initialize {type(self).__name__}, this could take a while"
        )

        # Resize every image to 64x64 and convert it to a tensor.
        preprocess = tv.transforms.Compose(
            [tv.transforms.Resize((64, 64)), tv.transforms.ToTensor()]
        )

        # Downloads the test split on first use, then reuses the local copy.
        gtsrb_test_split = tv.datasets.GTSRB(
            root="./rt_search_based/datasets/data/",
            split="test",
            download=True,
            transform=preprocess,
        )

        self._dataset = [
            DatasetItem(position, img, lbl)
            for position, (img, lbl) in enumerate(gtsrb_test_split)
        ]
        logging.info(f"Finished initializing {type(self).__name__}")


class MiniGtsrbDataset(Dataset):
    """
    First n images from GTSRB Dataset, used during development to speed up testing
    """

    # Number of images in the full GTSRB test split (see GtsrbDataset).
    _FULL_DATASET_SIZE = 12492

    def __init__(self, length: int = 50) -> None:
        """
        Builds a dataset from the first ``length`` images of the GTSRB test split.

        Raises:
            ValueError: if ``length`` is not in the range [1, 12491].
        """
        logging.info(
            f"Starting to initialize {self.__class__.__name__}, this could take a while"
        )

        self.length = length

        # Validate before triggering the (potentially slow) download below.
        # A mini dataset must be non-empty and strictly smaller than the
        # original, otherwise GtsrbDataset should be used instead.
        if self.length < 1:
            raise ValueError("Requested length must be at least 1")
        if self.length >= self._FULL_DATASET_SIZE:
            raise ValueError(
                "Requested length must be smaller than the original dataset "
                f"({self._FULL_DATASET_SIZE} items)"
            )

        GTBSRB_dataset = tv.datasets.GTSRB(
            root="./rt_search_based/datasets/data/",
            split="test",
            download=True,
            transform=tv.transforms.Compose(
                [tv.transforms.Resize((64, 64)), tv.transforms.ToTensor()]
            ),
        )

        # Lazily take only the first n items instead of materializing the
        # whole split.
        first_n_in_gtsrb = itertools.islice(GTBSRB_dataset, self.length)

        self._dataset = [
            DatasetItem(index, image, label)
            for index, (image, label) in enumerate(first_n_in_gtsrb)
        ]
        logging.info(f"Finished initializing {type(self).__name__}")

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.length})"


# NOTE(review): module-level instantiation -- importing this module triggers a
# download/load of the GTSRB test split as a side effect, duplicating the
# construction already done inside GtsrbDataset.__init__ and
# MiniGtsrbDataset.__init__. Presumably leftover from debugging; consider
# removing. TODO: confirm no other module imports ``GTBSRB_dataset`` from here
# before deleting.
GTBSRB_dataset = tv.datasets.GTSRB(
    root="./rt_search_based/datasets/data/",
    split="test",
    download=True,
    transform=tv.transforms.Compose(
        [tv.transforms.Resize((64, 64)), tv.transforms.ToTensor()]
    ),
)