# ViTGuard/load_data.py
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import torch

class TinyDataset(Dataset):
    """Wraps a Hugging Face dataset for use with a PyTorch DataLoader.

    Tiny-ImageNet contains a small number of grayscale images; __getitem__
    returns None for those so callers can filter them out before batching.
    """
    def __init__(self, hf_dataset, transform=None):
        self.hf_dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        example = self.hf_dataset[idx]
        image = example['image']
        label = example['label']
        if self.transform:
            image = self.transform(image)
        # Drop single-channel (grayscale) images so every batch stacks to a
        # consistent 3-channel shape.
        if image.shape[0] == 1:
            return None
        return image, label
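
# Alternative sketch (not part of the original pipeline): rather than
# materializing the whole dataset just to drop the None entries, a collate_fn
# can filter them per batch and keep the Dataset lazy. `skip_none_collate` is a
# hypothetical helper (requires torch >= 1.11 for torch.utils.data.default_collate);
# pass it to DataLoader via collate_fn=skip_none_collate.
def skip_none_collate(batch):
    # Remove the None placeholders emitted for grayscale images, then batch as usual.
    batch = [item for item in batch if item is not None]
    return torch.utils.data.default_collate(batch)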

def load_tiny(batch_size=64, shuffle=False, is_train=False):
    imgSize = 224
    toTensorTransform = transforms.Compose([
        transforms.Resize(imgSize),
        transforms.ToTensor(),
    ])

    if is_train:
        tiny_imagenet_train = load_dataset('Maysee/tiny-imagenet', split='train')
        custom_train_dataset = TinyDataset(tiny_imagenet_train, transform=toTensorTransform)
        # Materialize the dataset and drop the grayscale images flagged as None.
        custom_train_dataset = [item for item in custom_train_dataset if item is not None]
        train_loader = DataLoader(custom_train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)
        return train_loader
    else:
        tiny_imagenet_val = load_dataset('Maysee/tiny-imagenet', split='valid')
        val_dataset = TinyDataset(tiny_imagenet_val, transform=toTensorTransform)
        val_dataset = [item for item in val_dataset if item is not None]
        # Fixed seed so the 20% validation subsample is reproducible.
        np.random.seed(0)
        val_index = np.random.choice(len(val_dataset), int(0.2 * len(val_dataset)), replace=False)
        custom_val_dataset = [val_dataset[i] for i in val_index]
        val_loader = DataLoader(custom_val_dataset, batch_size=batch_size, shuffle=shuffle)

        # The remaining 80% could serve as a held-out test split:
        # test_index = np.setdiff1d(np.arange(len(val_dataset)), val_index)
        # custom_test_dataset = [val_dataset[i] for i in test_index]
        # test_loader = DataLoader(custom_test_dataset, batch_size=batch_size, shuffle=shuffle)
        return val_loader
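
# Example usage (a sketch; assumes the 'Maysee/tiny-imagenet' Hugging Face
# dataset is downloadable):
#   val_loader = load_tiny(batch_size=64)             # 20% of the 'valid' split
#   images, labels = next(iter(val_loader))
#   print(images.shape)                               # torch.Size([64, 3, 224, 224])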

def GetCIFAR100Training(imgSize=32, batchSize=64):
    toTensorTransform = transforms.Compose([
        transforms.Resize(imgSize),
        transforms.ToTensor(),
    ])
    trainLoader = torch.utils.data.DataLoader(
        datasets.CIFAR100(root='dataset/', train=True, download=True, transform=toTensorTransform),
        batch_size=batchSize, shuffle=False, num_workers=1, pin_memory=True)
    return trainLoader

def GetCIFAR100Validation(imgSize=32, batchSize=64, ratio=1):
    transformTest = transforms.Compose([
        transforms.Resize(imgSize),
        transforms.ToTensor(),
    ])
    val_dataset = datasets.CIFAR100(root='dataset/', train=False, download=True, transform=transformTest)
    # Subsample a fixed, reproducible fraction (ratio) of the test set.
    np.random.seed(0)
    test_index = np.random.choice(len(val_dataset), int(ratio * len(val_dataset)), replace=False)
    val_dataset = [val_dataset[i] for i in test_index]
    valLoader = torch.utils.data.DataLoader(val_dataset, batch_size=batchSize, shuffle=False,
                                            num_workers=1, pin_memory=True)
    return valLoader
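
# Example (a sketch): ratio controls the fraction of the 10,000-image CIFAR-100
# test set that is kept; e.g. ratio=0.2 draws 2,000 images with a fixed seed.
#   valLoader = GetCIFAR100Validation(imgSize=224, batchSize=64, ratio=0.2)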

def GetCIFAR10Training(imgSize=32, batchSize=64):
    toTensorTransform = transforms.Compose([
        transforms.Resize(imgSize),
        transforms.ToTensor(),
    ])
    trainLoader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='dataset/', train=True, download=True, transform=toTensorTransform),
        batch_size=batchSize, shuffle=False, num_workers=1, pin_memory=True)
    return trainLoader

def GetCIFAR10Validation(imgSize=32, batchSize=64, ratio=1):
    transformTest = transforms.Compose([
        transforms.Resize(imgSize),
        transforms.ToTensor(),
    ])
    val_dataset = datasets.CIFAR10(root='dataset/', train=False, download=True, transform=transformTest)
    # Subsample a fixed, reproducible fraction (ratio) of the test set.
    np.random.seed(0)
    test_index = np.random.choice(len(val_dataset), int(ratio * len(val_dataset)), replace=False)
    val_dataset = [val_dataset[i] for i in test_index]
    valLoader = torch.utils.data.DataLoader(val_dataset, batch_size=batchSize, shuffle=False,
                                            num_workers=1, pin_memory=True)
    return valLoader
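
if __name__ == '__main__':
    # Minimal smoke test (a sketch; downloads CIFAR-10 into dataset/ on first run).
    loader = GetCIFAR10Validation(imgSize=224, batchSize=8, ratio=0.01)
    images, labels = next(iter(loader))
    # Expect torch.Size([8, 3, 224, 224]) and torch.Size([8]).
    print(images.shape, labels.shape)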