CodeExamples / Face Recognition / datasets.py
datasets.py
Raw
import os
import math


import torch
from torchvision.io import read_image
from torchvision.io import ImageReadMode as IRM

###
#   Split layout, per identity folder:
#     - the training split is the first 60 images (0.png - 59.png)
#     - the test split is the last 12 images (60.png - 71.png)
###

class DigiFaceDataset(torch.utils.data.Dataset):
    """Map-style dataset reading images from ``<path>/<label>/<image_id>.png``.

    ``label`` is the integer name of an identity folder and ``image_id`` the
    integer stem of an image file inside it.  Within every identity, image
    ids 0-59 belong to the "train" split and ids 60 and above to the "test"
    split.  Index arithmetic assumes 60 train / 12 test images per identity
    and identity folders numbered contiguously from 0.

    Args:
        path: Root directory holding one sub-folder per identity.
        split: Either ``"train"`` or ``"test"``.
        transform: Optional callable applied to the loaded image tensor.
        target_transform: Optional callable applied to the integer label.
        images_per_id: Stored on the instance for reference only; the split
            sizes used for indexing are the class constants below.

    Raises:
        ValueError: If ``split`` is neither ``"train"`` nor ``"test"``.
    """

    # Number of images each identity contributes to each split.
    TRAIN_IMAGES_PER_ID = 60
    TEST_IMAGES_PER_ID = 12

    def __init__(self, path, split, transform=None, target_transform=None, images_per_id=72):
        super().__init__()
        if split not in ("train", "test"):
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Invalid split")
        self.path = path
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.images_per_id = images_per_id
        self.count_images(path)

    def __len__(self):
        """Return the number of images counted for the selected split."""
        return self.number_of_images

    def __getitem__(self, idx):
        """Load the ``idx``-th image of the split together with its label.

        Args:
            idx: Position within the split; negative values count from the end.

        Returns:
            A ``(image, label)`` tuple: the (optionally transformed) image
            converted with ``.float()``, and the (optionally transformed)
            integer identity label.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            idx = idx + total
        # Raising IndexError (instead of failing later on a missing file) lets
        # plain ``for item in dataset`` iteration terminate cleanly.
        if not 0 <= idx < total:
            raise IndexError("dataset index out of range")

        if self.split == "train":
            per_id, first_image_id = self.TRAIN_IMAGES_PER_ID, 0
        else:
            # Test images are stored after the training images of each identity.
            per_id, first_image_id = self.TEST_IMAGES_PER_ID, self.TRAIN_IMAGES_PER_ID

        label, position = divmod(idx, per_id)
        image_path = os.path.join(
            self.path, str(label), f"{first_image_id + position}.png"
        )
        image = read_image(image_path, mode=IRM.RGB)

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image.float(), label

    def count_images(self, path):
        """Count the split's images under ``path`` into ``self.number_of_images``.

        Walks ``path`` recursively and counts files named ``<integer>.png``
        whose integer is below 60 for the "train" split, or 60 and above for
        the "test" split.  Any other file (e.g. OS metadata files) is ignored.

        Args:
            path: Root directory to walk.
        """
        wants_train = self.split == "train"
        boundary = self.TRAIN_IMAGES_PER_ID
        total = 0
        for _root, _dirs, files in os.walk(path):
            for name in files:
                stem, ext = os.path.splitext(name)
                # Only names that __getitem__ can actually address are counted.
                if ext != ".png" or not stem.isdecimal():
                    continue
                if (int(stem) < boundary) == wants_train:
                    total += 1
        self.number_of_images = total

#print("Hi")
#dataset = DigiFaceDataset("datasets/DigiFace/subjects_0-1999_72_imgs")
#print(dataset.number_of_images)

#print(dataset[0][0].short())
#print(dataset[0][0].shape)