# JSCC / VAE / utils.py
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Normalize
import matplotlib.pyplot as plt
import torch

def load_data(root="./VAE/data", download=False, dataset="MNIST"):
    """Load the train and test splits of one torchvision image dataset.

    Both splits come from the same dataset class. Mixing datasets (e.g.
    training on MNIST but testing on FashionMNIST) would evaluate the model
    on a distribution it never saw during training.

    Args:
        root: Directory torchvision reads the dataset files from (and writes
            them to when ``download`` is True).
        download: Whether torchvision should fetch the files if they are
            missing under ``root``.
        dataset: Name of a class in ``torchvision.datasets`` whose constructor
            accepts ``root``/``train``/``download``/``transform``, e.g.
            ``"MNIST"`` or ``"FashionMNIST"``.

    Returns:
        A ``(training_data, test_data)`` tuple of dataset objects whose
        samples are converted with ``ToTensor()``.

    Raises:
        ValueError: If ``dataset`` is not a name in ``torchvision.datasets``.
    """
    dataset_cls = getattr(datasets, dataset, None)
    if dataset_cls is None:
        raise ValueError(f"Unknown torchvision dataset: {dataset!r}")

    training_data = dataset_cls(
        root=root,
        train=True,
        download=download,
        transform=ToTensor(),
    )
    test_data = dataset_cls(
        root=root,
        train=False,
        download=download,
        transform=ToTensor(),
    )
    return training_data, test_data

def create_dataloaders(train_data, test_data, batch_size=32):
    '''
    Wrap the train and test datasets in PyTorch DataLoaders.

    Args:
        train_data: Dataset used for training; its loader reshuffles every epoch.
        test_data: Dataset used for evaluation; its loader keeps the original
            order so results are reproducible.
        batch_size: Number of samples per batch for both loaders (default 32).

    Returns:
        A ``(train_dl, test_dl)`` tuple of ``torch.utils.data.DataLoader``.
    '''
    train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_dl = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return train_dl, test_dl

def prior_sampling(model, z_dim=64, n_samples=15):
    """Decode latent vectors drawn from a standard normal prior and plot them in a row.

    The model is switched to eval mode via ``model.eval()`` and left in that
    mode afterwards.

    Args:
        model: Module exposing a ``decoder`` attribute that maps a
            ``(n_samples, z_dim)`` latent tensor to a batch of images. Each
            decoded sample is plotted as ``img[0]``, so this assumes
            channel-first output such as ``(n_samples, 1, H, W)`` -- TODO
            confirm against the decoder.
        z_dim: Dimensionality of the latent space.
        n_samples: Number of images to sample and display (must be >= 1).

    Raises:
        ValueError: If ``n_samples`` is less than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    model.eval()

    # Create the noise on the model's device: a CPU tensor fed to a decoder
    # whose weights live on CUDA raises a device-mismatch error.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        input_sample = torch.randn(n_samples, z_dim, device=device)
        sampled_images = model.decoder(input_sample)

    # squeeze=False always returns a 2-D array of Axes; without it a single
    # subplot (n_samples == 1) comes back as a bare Axes that zip can't iterate.
    fig, axes = plt.subplots(1, n_samples, figsize=(n_samples * 2, 2.5), squeeze=False)
    for img, ax in zip(sampled_images, axes[0]):
        ax.imshow(img[0].cpu().numpy(), cmap='gray')  # .cpu() needed when decoding on CUDA
        ax.axis('off')  # Hide axis ticks and labels

    plt.suptitle('Generated Prior Sample')
    plt.show()