import csv
from pathlib import Path
from typing import Iterable, List, Optional, Sequence, Union

import torch


class CsvManager:
    """Keep an integer table in sync between a CSV file and a ``torch.Tensor``.

    The CSV file consists of one header row (the column names, ``fields``)
    followed by rows of integers. ``read_csv`` loads the rows into
    ``data_tensor``; ``__getitem__`` / ``__setitem__`` index that tensor;
    ``write_csv`` stores it back to disk.
    """

    def __init__(
        self,
        length: int,
        path: Path,
        custom_fields: Optional[Sequence[str]] = None,
    ) -> None:
        """Store the configuration; no file is touched here.

        Args:
            length: Number of data rows ``initialize_csv`` writes.
            path: Location of the CSV file.
            custom_fields: Column names for the header row. May be left as
                ``None`` when the header is going to be loaded by ``read_csv``.
        """
        self.path = path
        self.length = length
        self.fields = custom_fields

    def read_csv(self) -> None:
        """Load the CSV file into ``data_tensor`` (dtype ``torch.int64``).

        ``fields`` is overwritten with the header row found in the file. The
        loaded data is afterwards accessible via ``__getitem__`` and
        ``__setitem__``.

        Raises:
            OSError: If the file cannot be opened.
            ValueError: If a cell cannot be converted with ``int()``.
        """
        # newline="" lets the csv module do its own line-ending handling,
        # as required by the csv documentation.
        with open(self.path, mode="r", encoding="utf-8", newline="") as csv_file:
            # The reader is kept local: it is useless once the file is closed.
            reader = csv.DictReader(csv_file)
            self.fields = reader.fieldnames
            rows = [[int(cell) for cell in record.values()] for record in reader]
        self.data_tensor = torch.tensor(rows, dtype=torch.int64)

    def __getitem__(self, index: Union[int, slice]) -> torch.Tensor:
        """Return ``data_tensor[index]`` (requires a prior ``read_csv``)."""
        return self.data_tensor[index]

    def __setitem__(self, index: Union[int, slice], data: torch.Tensor) -> None:
        """Assign ``data`` to ``data_tensor[index]`` (requires a prior ``read_csv``)."""
        self.data_tensor[index] = data

    def write_csv(self) -> None:
        """Write the header and the contents of ``data_tensor`` to the CSV file.

        Missing parent directories of ``path`` are created first.

        Raises:
            ValueError: If no column names are available.
            AttributeError: If ``data_tensor`` has not been set yet.
        """
        # Convert before the file is opened so that a failure here cannot
        # leave a truncated file behind.
        rows = [[int(value) for value in row] for row in self.data_tensor.tolist()]
        self._write_rows(rows)

    def initialize_csv(self) -> None:
        """Create the CSV file with the header and ``length`` rows of ``-1``.

        Every column named in ``fields`` gets the placeholder value ``-1`` in
        every row. Missing parent directories of ``path`` are created first.

        Raises:
            ValueError: If no column names are available.
        """
        placeholder_row = [-1] * len(self._require_fields())
        self._write_rows([placeholder_row] * self.length)

    def delete(self) -> None:
        """Remove the CSV file from the filesystem."""
        self.path.unlink()

    def _require_fields(self) -> List[str]:
        """Return the column names, raising ``ValueError`` if there are none."""
        if self.fields is None:
            raise ValueError(
                "no column names available: pass custom_fields or call read_csv() first"
            )
        return list(self.fields)

    def _write_rows(self, rows: Iterable[Sequence[int]]) -> None:
        """Write the header row followed by ``rows`` to ``path``, replacing the file."""
        header = self._require_fields()
        # Create the directory that CONTAINS the file. Calling mkdir() on the
        # file path itself would create a directory with the file's name (or
        # raise FileExistsError once the file exists, even with exist_ok=True).
        Path(self.path).parent.mkdir(parents=True, exist_ok=True)
        with open(self.path, mode="w", encoding="utf-8", newline="") as csv_file:
            writer = csv.writer(
                csv_file,
                delimiter=",",
                quotechar='"',
                quoting=csv.QUOTE_MINIMAL,
            )
            writer.writerow(header)
            writer.writerows(rows)