from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Normalize
import matplotlib.pyplot as plt
import torch


def load_data():
    """Load the training and test datasets stored under ``./VAE/data``.

    Both datasets are opened with ``download=False``, so the files must
    already exist under the root directory. Every image is converted with
    ``ToTensor()``.

    Returns:
        A ``(training_data, test_data)`` tuple of torchvision datasets.
    """
    training_data = datasets.MNIST(
        root="./VAE/data",
        train=True,
        download=False,
        transform=ToTensor(),
    )

    # NOTE(review): the training split is MNIST but the test split is
    # FashionMNIST. Confirm this is intentional (e.g. out-of-distribution
    # evaluation) and not a copy-paste slip before relying on test metrics.
    test_data = datasets.FashionMNIST(
        root="./VAE/data",
        train=False,
        download=False,
        transform=ToTensor(),
    )

    return training_data, test_data


def create_dataloaders(train_data, test_data, batch_size=32):
    """Wrap the train and test datasets in ``DataLoader`` objects.

    Args:
        train_data: Dataset used for training; batches are shuffled.
        test_data: Dataset used for evaluation; batches keep dataset order.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_dl, test_dl)`` tuple of ``torch.utils.data.DataLoader``.
    """
    train_dl = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True
    )
    test_dl = torch.utils.data.DataLoader(
        test_data, batch_size=batch_size, shuffle=False
    )
    return train_dl, test_dl


def _model_device(model):
    """Return the device of the model's first parameter, or CPU if none."""
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is None:
        return torch.device("cpu")
    return first_param.device


def prior_sampling(model, z_dim=64, n_samples=15):
    """Decode latents drawn from a standard normal prior and plot the results.

    The model is switched to eval mode (and left in that mode), ``n_samples``
    latent vectors of size ``z_dim`` are sampled with ``torch.randn`` and
    passed through ``model.decoder`` without gradient tracking. The decoded
    samples are shown side by side in a single matplotlib figure.

    Args:
        model: Object exposing ``eval()`` and a callable ``decoder``.
        z_dim: Dimensionality of each latent vector.
        n_samples: Number of latent vectors to sample and images to show.

    Note:
        Only ``img[0]`` of each decoded sample is displayed; presumably the
        decoder returns ``(n_samples, channels, H, W)`` so this is channel 0
        -- TODO confirm against the decoder's output shape.
    """
    model.eval()

    # Sample on the model's device so a CUDA model does not receive CPU input.
    input_sample = torch.randn(n_samples, z_dim, device=_model_device(model))

    with torch.no_grad():
        sampled_images = model.decoder(input_sample)

    # squeeze=False keeps `axes` two-dimensional, so iterating the single row
    # also works when n_samples == 1 (otherwise a bare Axes is returned).
    _, axes = plt.subplots(
        1, n_samples, figsize=(n_samples * 2, 2.5), squeeze=False
    )
    for img, ax in zip(sampled_images, axes[0]):
        # Move to CPU before converting to NumPy in case of a CUDA tensor.
        ax.imshow(img[0].cpu().numpy(), cmap='gray')
        ax.axis('off')  # Hide axis ticks and frame.

    plt.suptitle('Generated Prior Sample')
    plt.show()