import os import math import torch from torchvision.io import read_image from torchvision.io import ImageReadMode as IRM ### # Training split will be first 60 images # test split will be last 12 ### class DigiFaceDataset(torch.utils.data.Dataset): def __init__(self, path, split, transform=None, target_transform=None, images_per_id = 72): self.path = path if (split != "train") and (split != "test"): raise Exception("Invalid split") self.split = split self.transform = transform self.target_transform = target_transform self.images_per_id = images_per_id self.count_images(path) def __len__(self): return self.number_of_images def __getitem__(self, idx): split_count = 0 if self.split == "train": split_count = 60 elif self.split == "test": split_count = 12 label = math.floor(idx / split_count) image_id = idx % split_count + (60 if self.split == "test" else 0) image = read_image(self.path + '/' + str(label) + '/' + str(image_id) + '.png', mode = IRM.RGB) if self.transform: image = self.transform(image) if self.target_transform: label = self.target_transform(label) return image.float(), label def count_images(self, path): self.number_of_images = 0 for root, dirs, files in os.walk(path): images = [] if self.split is "train": images = [x for x in files if int(x[:-4]) < 60] elif self.split is "test": images = [x for x in files if int(x[:-4]) >= 60] self.number_of_images += len(images) #print("Hi") #dataset = DigiFaceDataset("datasets/DigiFace/subjects_0-1999_72_imgs") #print(dataset.number_of_images) #print(dataset[0][0].short()) #print(dataset[0][0].shape)