from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Normalize, CenterCrop
import matplotlib.pyplot as plt
import torch
import numpy as np


def _celeba_transform(crop_size=(128, 128)):
    """Build the preprocessing pipeline shared by the CelebA train/test splits.

    Args:
        crop_size: Size passed to ``CenterCrop``.

    Returns:
        A ``Compose`` of CenterCrop -> ToTensor -> Normalize.
    """
    return Compose([
        CenterCrop(crop_size),
        ToTensor(),
        # NOTE: mean 0 / std 1 makes this Normalize an identity transform.
        # It is kept so the pipeline is unchanged; replace with real dataset
        # statistics if normalization is actually wanted.
        Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
    ])


def load_data(root="./data", download=False, crop_size=(128, 128)):
    """Load the CelebA train and test splits with identical preprocessing.

    Args:
        root: Directory passed to ``datasets.CelebA`` as the dataset root.
        download: Forwarded to ``datasets.CelebA``; when False the data is
            expected to already exist under ``root``.
        crop_size: Center-crop size applied to every image.

    Returns:
        A ``(training_data, test_data)`` tuple of ``datasets.CelebA`` objects.
    """
    # Both splits share one transform object; every step in it is stateless.
    transform = _celeba_transform(crop_size)
    training_data = datasets.CelebA(
        root=root, split="train", download=download, transform=transform
    )
    test_data = datasets.CelebA(
        root=root, split="test", download=download, transform=transform
    )
    return training_data, test_data


def create_dataloaders(train_data, test_data, batch_size=32):
    """Wrap the train and test datasets in ``DataLoader`` objects.

    Both loaders shuffle and use ``drop_last=True``, so a final incomplete
    batch is discarded. For the test loader this means up to
    ``batch_size - 1`` samples are never seen.
    NOTE(review): consider ``shuffle=False, drop_last=False`` for evaluation.

    Args:
        train_data: Dataset used for training.
        test_data: Dataset used for evaluation.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_dl, test_dl)`` tuple of ``torch.utils.data.DataLoader``.
    """
    train_dl = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, drop_last=True
    )
    test_dl = torch.utils.data.DataLoader(
        test_data, batch_size=batch_size, shuffle=True, drop_last=True
    )
    return train_dl, test_dl


def show(img, path="output_image.png"):
    """Render a channel-first image tensor without axes and save it to disk.

    Args:
        img: Tensor whose first axis is channels. It must be rank 3, because
            it is transposed with axes ``(1, 2, 0)`` to height-width-channel
            for ``imshow``. It may live on any device and may require grad.
        path: Output file passed to ``plt.savefig``.
    """
    # detach/cpu so GPU tensors and tensors requiring grad can be converted.
    npimg = img.detach().cpu().numpy()
    fig = plt.figure(figsize=(10, 10))
    try:
        axes_image = plt.imshow(
            np.transpose(npimg, (1, 2, 0)), interpolation="nearest"
        )
        axes_image.axes.get_xaxis().set_visible(False)
        axes_image.axes.get_yaxis().set_visible(False)
        plt.savefig(path)
    finally:
        # Close exactly the figure created here, even if rendering failed,
        # so repeated calls do not accumulate open figures.
        plt.close(fig)