# JSCC / VQ-VAE / utils.py
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Normalize, CenterCrop
import matplotlib.pyplot as plt
import torch
import numpy as np

def load_data(root="./data", *, download=False, crop_size=(128, 128)):
    """Load the CelebA train and test splits with one shared preprocessing pipeline.

    Every image is center-cropped to ``crop_size``, converted to a tensor with
    ``ToTensor`` and passed through ``Normalize`` with mean 0 and std 1.

    Args:
        root: Directory passed to ``datasets.CelebA`` as its ``root``.
        download: Forwarded to ``datasets.CelebA``. Defaults to ``False``, so
            the data is expected to already be present under ``root``.
        crop_size: Size passed to ``CenterCrop`` (default ``(128, 128)``).

    Returns:
        A ``(training_data, test_data)`` tuple of ``datasets.CelebA`` objects
        for the ``'train'`` and ``'test'`` splits.
    """
    # Build the pipeline once so the train and test preprocessing cannot drift apart.
    # NOTE(review): Normalize with mean 0 / std 1 is an identity transform;
    # presumably a placeholder for real per-channel statistics — confirm.
    transform = Compose([
        CenterCrop(crop_size),
        ToTensor(),
        Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
    ])
    training_data = datasets.CelebA(
        root=root,
        split='train',
        download=download,
        transform=transform,
    )
    test_data = datasets.CelebA(
        root=root,
        split='test',
        download=download,
        transform=transform,
    )
    return training_data, test_data

def create_dataloaders(train_data, test_data, *, batch_size=32,
                       shuffle_test=True, drop_last=True):
    """Wrap a train and a test dataset in ``torch.utils.data.DataLoader`` objects.

    Args:
        train_data: Dataset used for training; its loader always shuffles.
        test_data: Dataset used for evaluation.
        batch_size: Number of samples per batch for both loaders (default 32).
        shuffle_test: Whether the test loader shuffles. Defaults to ``True``
            to keep the historical behaviour; pass ``False`` for a
            reproducible evaluation order.
        drop_last: Drop the trailing incomplete batch in both loaders. With
            the default ``True``, up to ``batch_size - 1`` test samples are
            never yielded; pass ``False`` to iterate the full test split.

    Returns:
        A ``(train_dl, test_dl)`` tuple of DataLoaders.
    """
    train_dl = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, drop_last=drop_last
    )
    test_dl = torch.utils.data.DataLoader(
        test_data, batch_size=batch_size, shuffle=shuffle_test, drop_last=drop_last
    )
    return train_dl, test_dl


def show(img, path="output_image.png"):
    """Render an image tensor with matplotlib and save it to ``path``.

    Args:
        img: Tensor laid out as (channels, height, width); it is transposed to
            (height, width, channels) before ``imshow``. It may require grad
            or live on a non-CPU device.
        path: Destination file for ``savefig``. Defaults to the previously
            hard-coded ``"output_image.png"``.
    """
    # detach()/cpu() so tensors taken straight from a model (grad-tracking or
    # on GPU) can be converted; .numpy() raises for either case otherwise.
    npimg = img.detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
    # Hide tick labels and ticks on both axes; only the image is of interest.
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    fig.savefig(path)
    # Close this specific figure so repeated calls don't accumulate open figures.
    plt.close(fig)