import os
import math
import torch
from torchvision.io import read_image
from torchvision.io import ImageReadMode as IRM
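###
# Split convention: for each identity, the first 60 images (ids 0-59) form
# the training split and the remaining images (ids 60-71 with the default
# images_per_id=72, i.e. the last 12) form the test split.
###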
class DigiFaceDataset(torch.utils.data.Dataset):
    """DigiFace identity dataset read from a directory tree of PNG images.

    Images are loaded from ``<path>/<label>/<image_id>.png``. ``label`` is the
    integer identity (the sub-directory name) and ``image_id`` runs from 0 to
    ``images_per_id - 1``. The first ``TRAIN_IMAGES_PER_ID`` images of each
    identity form the "train" split; the remaining
    ``images_per_id - TRAIN_IMAGES_PER_ID`` form the "test" split.

    NOTE(review): the index -> label mapping assumes identity folders are
    named 0..N-1 contiguously and that each one holds every image of its
    split — TODO confirm against the dataset on disk.
    """

    # Number of leading images per identity assigned to the training split.
    TRAIN_IMAGES_PER_ID = 60

    def __init__(self, path, split, transform=None, target_transform=None,
                 images_per_id=72):
        """Build the dataset and count the images belonging to ``split``.

        Args:
            path: root directory containing one sub-directory per identity.
            split: either "train" or "test".
            transform: optional callable applied to the loaded image.
            target_transform: optional callable applied to the integer label.
            images_per_id: total number of images per identity (train + test).

        Raises:
            ValueError: if ``split`` is invalid, or if ``images_per_id`` leaves
                no images for the test split.
        """
        if split not in ("train", "test"):
            raise ValueError(
                f"Invalid split {split!r}; expected 'train' or 'test'")
        if split == "test" and images_per_id <= self.TRAIN_IMAGES_PER_ID:
            raise ValueError(
                f"images_per_id={images_per_id} leaves no test images "
                f"(first {self.TRAIN_IMAGES_PER_ID} are reserved for train)")
        self.path = path
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.images_per_id = images_per_id
        # Must run after split/images_per_id are set: it reads both.
        self.count_images(path)

    def __len__(self):
        """Return the number of images found on disk for this split."""
        return self.number_of_images

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image.float(), label)``.

        ``transform`` / ``target_transform`` are applied when provided.
        """
        label, image_id = self._index_to_location(idx)
        image = read_image(
            os.path.join(self.path, str(label), f"{image_id}.png"),
            mode=IRM.RGB)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image.float(), label

    def _split_bounds(self):
        """Return ``(first_image_id, images_in_split)`` for the active split."""
        if self.split == "train":
            return 0, self.TRAIN_IMAGES_PER_ID
        return (self.TRAIN_IMAGES_PER_ID,
                self.images_per_id - self.TRAIN_IMAGES_PER_ID)

    def _index_to_location(self, idx):
        """Map a flat dataset index to ``(label, image_id)`` on disk."""
        first_id, split_size = self._split_bounds()
        # Integer division: consecutive blocks of `split_size` indices
        # belong to the same identity.
        label, offset = divmod(idx, split_size)
        return label, first_id + offset

    def count_images(self, path):
        """Count the ``<digits>.png`` files under ``path`` that fall in this split.

        Sets ``self.number_of_images``. Other files (e.g. ``.DS_Store``) are
        ignored rather than crashing the integer parse.
        """
        first_id, split_size = self._split_bounds()
        last_id = first_id + split_size  # exclusive upper bound
        total = 0
        for _root, _dirs, files in os.walk(path):
            for name in files:
                stem, ext = os.path.splitext(name)
                # isdecimal() matches exactly what int() can parse here.
                if (ext == ".png" and stem.isdecimal()
                        and first_id <= int(stem) < last_id):
                    total += 1
        self.number_of_images = total
#print("Hi")
#dataset = DigiFaceDataset("datasets/DigiFace/subjects_0-1999_72_imgs")
#print(dataset.number_of_images)
#print(dataset[0][0].short())
#print(dataset[0][0].shape)