from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import torch
class TinyDataset(Dataset):
    """Adapter exposing a Hugging Face image dataset to torch DataLoader.

    Items whose transformed image has a single channel (grayscale) are
    mapped to ``None`` so callers can filter them out before batching —
    the default collate function cannot stack mixed 1- and 3-channel
    tensors (see load_tiny, which does exactly that filtering).
    """

    def __init__(self, hf_dataset, transform=None):
        """
        Args:
            hf_dataset: indexable dataset whose items are dicts with
                'image' and 'label' keys (e.g. a HF `datasets` split).
            transform: optional callable applied to each image
                (typically torchvision transforms producing a CHW tensor).
        """
        self.hf_dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        """Return an (image, label) pair, or None for grayscale images."""
        example = self.hf_dataset[idx]
        image = example['image']
        label = example['label']
        if self.transform:
            image = self.transform(image)
        # Bug fix: the original did `image.shape[0]` unconditionally, which
        # crashes on an untransformed PIL image (no `.shape`). Only apply the
        # channel check when the image actually exposes a shape.
        shape = getattr(image, 'shape', None)
        if shape is not None and shape[0] == 1:
            # Single-channel (grayscale) item: signal the caller to drop it.
            return None
        return image, label
def load_tiny(batch_size=64, shuffle=False, is_train=False):
    """Build a DataLoader over the 'Maysee/tiny-imagenet' dataset.

    Images are resized to 224 and converted to tensors; grayscale items
    (for which TinyDataset yields None) are dropped. For validation, a
    fixed, reproducible 20% random subset of the 'valid' split is used.

    Args:
        batch_size: batch size for the returned DataLoader.
        shuffle: whether the DataLoader reshuffles each epoch.
        is_train: True -> full 'train' split; False -> 20% of 'valid'.

    Returns:
        torch.utils.data.DataLoader yielding (image, label) batches.
    """
    img_size = 224
    to_tensor_transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ])
    if is_train:
        train_split = load_dataset('Maysee/tiny-imagenet', split='train')
        # Materialize and drop grayscale items (None); the default
        # collate_fn cannot batch None entries.
        train_items = [
            item
            for item in TinyDataset(train_split, transform=to_tensor_transform)
            if item is not None
        ]
        return DataLoader(train_items, batch_size=batch_size,
                          shuffle=shuffle, num_workers=4)
    val_split = load_dataset('Maysee/tiny-imagenet', split='valid')
    val_items = [
        item
        for item in TinyDataset(val_split, transform=to_tensor_transform)
        if item is not None
    ]
    # Local RandomState(0) reproduces the exact same index sequence as the
    # original np.random.seed(0) + np.random.choice, without clobbering the
    # global numpy RNG state for other callers.
    val_index = np.random.RandomState(0).choice(
        len(val_items), int(0.2 * len(val_items)), replace=False)
    val_subset = [val_items[i] for i in val_index]
    return DataLoader(val_subset, batch_size=batch_size, shuffle=shuffle)
def GetCIFAR100Training(imgSize = 32, batchSize=64):
    """Return a DataLoader over the CIFAR-100 training split.

    Images are resized to ``imgSize`` and converted to tensors; the
    loader does not shuffle (shuffle=False), so iteration order is fixed.
    """
    preprocess = transforms.Compose([
        transforms.Resize(imgSize),
        transforms.ToTensor(),
    ])
    train_set = datasets.CIFAR100(
        root='dataset/', train=True, download=True, transform=preprocess)
    return torch.utils.data.DataLoader(
        train_set, batch_size=batchSize, shuffle=False,
        num_workers=1, pin_memory=True)
def GetCIFAR100Validation(imgSize = 32, batchSize=64, ratio=1):
    """Return a DataLoader over a fixed random subset of the CIFAR-100 test split.

    Args:
        imgSize: side length images are resized to.
        batchSize: DataLoader batch size.
        ratio: fraction of the test split to keep; ratio=1 keeps every
            sample, in a fixed shuffled order.

    Returns:
        torch.utils.data.DataLoader over the (eagerly transformed) subset.
    """
    transformTest = transforms.Compose([
        transforms.Resize(imgSize),
        transforms.ToTensor(),
    ])
    val_dataset = datasets.CIFAR100(root='dataset/', train=False,
                                    download=True, transform=transformTest)
    # Local RandomState(0) yields the identical index sequence as the
    # original np.random.seed(0) + np.random.choice, without mutating the
    # global numpy RNG state for other callers.
    test_index = np.random.RandomState(0).choice(
        len(val_dataset), int(ratio * len(val_dataset)), replace=False)
    # Materializing applies the transform once per sample up front.
    val_subset = [val_dataset[i] for i in test_index]
    return torch.utils.data.DataLoader(val_subset, batch_size=batchSize,
                                       shuffle=False, num_workers=1,
                                       pin_memory=True)
def GetCIFAR10Training(imgSize = 32, batchSize=64):
    """Return a DataLoader over the CIFAR-10 training split.

    Images are resized to ``imgSize`` and converted to tensors; the
    loader does not shuffle (shuffle=False), so iteration order is fixed.
    """
    preprocess = transforms.Compose([
        transforms.Resize(imgSize),
        transforms.ToTensor(),
    ])
    train_set = datasets.CIFAR10(
        root='dataset/', train=True, download=True, transform=preprocess)
    return torch.utils.data.DataLoader(
        train_set, batch_size=batchSize, shuffle=False,
        num_workers=1, pin_memory=True)
def GetCIFAR10Validation(imgSize = 32, batchSize=64, ratio=1):
    """Return a DataLoader over a fixed random subset of the CIFAR-10 test split.

    Args:
        imgSize: side length images are resized to.
        batchSize: DataLoader batch size.
        ratio: fraction of the test split to keep; ratio=1 keeps every
            sample, in a fixed shuffled order.

    Returns:
        torch.utils.data.DataLoader over the (eagerly transformed) subset.
    """
    transformTest = transforms.Compose([
        transforms.Resize(imgSize),
        transforms.ToTensor(),
    ])
    val_dataset = datasets.CIFAR10(root='dataset/', train=False,
                                   download=True, transform=transformTest)
    # Local RandomState(0) yields the identical index sequence as the
    # original np.random.seed(0) + np.random.choice, without mutating the
    # global numpy RNG state for other callers.
    test_index = np.random.RandomState(0).choice(
        len(val_dataset), int(ratio * len(val_dataset)), replace=False)
    # Materializing applies the transform once per sample up front.
    val_subset = [val_dataset[i] for i in test_index]
    return torch.utils.data.DataLoader(val_subset, batch_size=batchSize,
                                       shuffle=False, num_workers=1,
                                       pin_memory=True)