# traffic-sign-classifier-robustness-testing / rt_search_based / database / csv_manager.py
import csv
from pathlib import Path
from typing import Optional, Sequence, Union

import torch


class CsvManager:
    """Store an integer table as a CSV file and expose it as a ``torch.Tensor``.

    The first CSV row holds the column names (``self.fields``); every following
    row holds one integer per column. :meth:`read_csv` loads the file into
    ``self.data_tensor`` (dtype ``torch.int64``), which can then be read and
    modified through ``manager[index]`` and written back with :meth:`write_csv`.
    """

    def __init__(
        self,
        length: int,
        path: Union[str, Path],
        custom_fields: Optional[Sequence[str]] = None,
    ) -> None:
        """Configure the manager without touching the filesystem.

        Args:
            length: Number of data rows that :meth:`initialize_csv` writes.
            path: Location of the CSV file itself (not its directory). Strings
                are converted to :class:`pathlib.Path`.
            custom_fields: Column names used when writing. They are replaced by
                the file's header whenever :meth:`read_csv` is called.
        """
        self.path = Path(path)
        self.length = length
        self.fields = custom_fields

    def read_csv(self) -> None:
        """Load the CSV file at ``self.path`` into ``self.data_tensor``.

        The header row replaces ``self.fields``. Each data row is converted to
        ``int`` and stacked into a ``(rows, columns)`` ``torch.int64`` tensor.
        A header-only file yields an empty ``(0, columns)`` tensor.

        Raises:
            FileNotFoundError: if ``self.path`` does not exist.
            ValueError / TypeError: if a cell is not an integer literal or a
                row has fewer cells than the header.
        """
        # newline="" is required by the csv module so that line terminators
        # and quoted newlines are handled by the reader, not by text-mode I/O.
        with open(self.path, mode="r", encoding="utf-8", newline="") as csv_file:
            reader = csv.DictReader(csv_file)
            self.fields = reader.fieldnames
            # DictReader builds each row from zip(fieldnames, cells), so
            # row.values() follows header order.
            rows = [[int(value) for value in row.values()] for row in reader]

        if rows:
            self.data_tensor = torch.tensor(rows, dtype=torch.int64)
        else:
            # torch.tensor([]) would be 1-D; keep the column dimension instead.
            n_columns = len(self.fields) if self.fields else 0
            self.data_tensor = torch.empty((0, n_columns), dtype=torch.int64)

    def __getitem__(self, index: Union[int, slice]) -> torch.Tensor:
        """Return ``data_tensor[index]`` (requires a prior :meth:`read_csv`)."""
        return self.data_tensor[index]

    def __setitem__(self, index: Union[int, slice], data: torch.Tensor):
        """Assign ``data`` to ``data_tensor[index]`` (requires a prior :meth:`read_csv`)."""
        self.data_tensor[index] = data

    def write_csv(self) -> None:
        """Write ``self.fields`` and every row of ``self.data_tensor`` to ``self.path``.

        Missing parent directories are created, and any existing file is
        overwritten. Each value is written with ``int()``.

        Raises:
            ValueError: if no column names are set. This is checked before the
                file is opened, so an existing file is not truncated.
        """
        fields = self._require_fields()
        self._ensure_parent_dir()
        with open(self.path, mode="w", encoding="utf-8", newline="") as csv_file:
            csv_writer = self._make_writer(csv_file)
            csv_writer.writerow(fields)
            csv_writer.writerows(
                [int(value) for value in row.tolist()] for row in self.data_tensor
            )

    def initialize_csv(self) -> None:
        """Create or overwrite the CSV file with ``self.length`` rows of ``-1``.

        The header is ``self.fields``, and every data cell is ``-1``. Missing
        parent directories are created.

        Raises:
            ValueError: if no column names are set. This is checked before the
                file is opened, so an existing file is not truncated.
        """
        fields = self._require_fields()
        self._ensure_parent_dir()
        with open(self.path, mode="w", encoding="utf-8", newline="") as csv_file:
            csv_writer = self._make_writer(csv_file)
            csv_writer.writerow(fields)
            placeholder_row = [-1] * len(fields)
            csv_writer.writerows(placeholder_row for _ in range(self.length))

    def delete(self):
        """Remove the CSV file (raises ``FileNotFoundError`` if it is missing)."""
        self.path.unlink()

    def _require_fields(self) -> Sequence[str]:
        """Return ``self.fields``, raising ``ValueError`` if no header is configured."""
        if self.fields is None:
            raise ValueError(
                f"{type(self).__name__} has no column names for {self.path}; "
                "pass custom_fields or call read_csv() first"
            )
        return self.fields

    def _ensure_parent_dir(self) -> None:
        """Create the directory that will contain ``self.path`` if it is missing."""
        # Create the parent, not self.path: mkdir on the file path itself would
        # create a directory where the CSV should go. Also, exist_ok only
        # tolerates existing directories, not existing files.
        self.path.parent.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _make_writer(csv_file):
        """Return a csv writer with the dialect settings used for every file."""
        return csv.writer(
            csv_file,
            delimiter=",",
            quotechar='"',
            quoting=csv.QUOTE_MINIMAL,
        )