import torch
import matplotlib.pyplot as plt

from model import VAE
from utils import (
    load_data,
    create_dataloaders,
    prior_sampling,
)


def loss_function(features, decoded, z_mean, z_log_var):
    """ELBO loss: pixel-wise reconstruction error plus the KL divergence
    between the approximate posterior N(mu, sigma^2) and the prior N(0, I)."""
    batch_size = features.size(0)

    # Analytic KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
    # summed over the latent dimension. Note the constant inside the sum is 1
    # per latent unit, not the latent dimensionality.
    kl_div = -0.5 * torch.sum(1 + z_log_var - z_mean**2 - torch.exp(z_log_var), dim=1)
    kl_div = kl_div.mean()  # average over batch dimension

    # Reconstruction term: squared error summed over pixels, averaged over the batch.
    pixelwise = torch.nn.functional.mse_loss(decoded, features, reduction='none')
    pixelwise = pixelwise.view(batch_size, -1).sum(dim=1)  # sum over pixels
    pixelwise = pixelwise.mean()  # average over batch dimension

    return pixelwise + kl_div


if __name__ == "__main__":
    train_ds, test_ds = load_data()
    train_dl, test_dl = create_dataloaders(train_ds, test_ds)

    model = VAE(784, 128, 5)  # input dim, hidden dim, latent dim
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    num_epochs = 10
    train_losses = []

    for epoch in range(num_epochs):
        model.train()
        for i, (batch, _) in enumerate(train_dl):
            recon_batch = model(batch)
            # Pass the originals first, reconstructions second, to match the
            # (features, decoded) signature of loss_function.
            loss = loss_function(batch, recon_batch, model.mu, model.log_var)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            if i % 100 == 0:
                print("Training epoch {}, iteration {} of {} ({} %), loss={:0.5f}".format(
                    epoch, i, len(train_dl), 100 * i // len(train_dl), train_losses[-1]))

    prior_sampling(model, z_dim=5)  # decode latent codes drawn from the prior N(0, I)

    # Evaluation on the test set (currently disabled):
    # model.eval()
    # loss_test = []
    # with torch.no_grad():
    #     for i, (batch, _) in enumerate(test_dl):
    #         pred = model(batch)
    #         if i == 1:
    #             plt.imshow(pred[i][0].numpy(), cmap='gray')
    #             plt.show()
    #         pixelwise = torch.nn.functional.mse_loss(pred, batch, reduction='none')
    #         pixelwise = pixelwise.view(pred.size(0), -1).sum(dim=1)  # sum over pixels
    #         pixelwise = pixelwise.mean()  # average over batch dimension
    #         loss_test.append(pixelwise.item())
    # print(f"mean error on test images = {sum(loss_test)/len(loss_test)}")

    print("DONE")
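
# ---------------------------------------------------------------------------
# Reference sketch (not used above): a minimal VAE matching the interface
# this script assumes from model.py -- VAE(input_dim, hidden_dim, z_dim)
# whose forward() caches self.mu and self.log_var for the loss computation.
# The single-hidden-layer structure and layer names below are assumptions
# for illustration, not the actual contents of model.py.
# ---------------------------------------------------------------------------
class _ReferenceVAE(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, z_dim):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
        )
        self.fc_mu = torch.nn.Linear(hidden_dim, z_dim)
        self.fc_log_var = torch.nn.Linear(hidden_dim, z_dim)
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(z_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
            torch.nn.Sigmoid(),  # pixel intensities in [0, 1]
        )

    def forward(self, x):
        h = self.encoder(x.view(x.size(0), -1))  # flatten images to vectors
        self.mu = self.fc_mu(h)
        self.log_var = self.fc_log_var(h)
        # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I),
        # so gradients flow through mu and log_var.
        eps = torch.randn_like(self.mu)
        z = self.mu + torch.exp(0.5 * self.log_var) * eps
        return self.decoder(z)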