# JSCC / VAE / main.py
import model
import torch
import matplotlib.pyplot as plt
from utils import(
    load_data,
    create_dataloaders,
    prior_sampling
)

def loss_function(features, decoded, z_mean, z_log_var):
    """Compute the VAE training loss: reconstruction error plus KL divergence.

    The reconstruction term is the squared error summed over all non-batch
    elements of each sample. The KL term is the closed-form divergence
    between the diagonal Gaussian posterior N(z_mean, exp(z_log_var)) and the
    standard-normal prior N(0, I), summed over latent dimensions. Both terms
    are averaged over the batch before being added.

    Args:
        features: Original inputs; dimension 0 is the batch.
        decoded: Reconstructions; same shape as ``features``.
        z_mean: Posterior means, assumed shape (batch, latent_dim).
        z_log_var: Posterior log-variances, same shape as ``z_mean``.

    Returns:
        A scalar tensor equal to mean reconstruction error + mean KL.
    """
    # Closed-form KL per latent dimension is
    # -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2). The constant is 1 per
    # dimension. The previous code used the latent size here, which shifted
    # the loss by -0.5 * k * (k - 1).
    kl_div = -0.5 * torch.sum(
        1 + z_log_var - z_mean.pow(2) - z_log_var.exp(), dim=1
    )  # sum over latent dimension
    batch_size = kl_div.size(0)
    kl_div = kl_div.mean()  # average over batch dimension

    pixelwise = torch.nn.functional.mse_loss(decoded, features, reduction="none")
    # reshape (not view) so non-contiguous inputs are also accepted.
    pixelwise = pixelwise.reshape(batch_size, -1).sum(dim=1)  # sum over pixels
    pixelwise = pixelwise.mean()  # average over batch dimension

    return pixelwise + kl_div

if __name__ == "__main__":
    # Train a VAE, track train/test losses, then sample from the prior.

    # Presumably the VAE's latent size; it is passed both to the model
    # constructor and to prior_sampling, so it is defined once here.
    latent_dim = 5
    num_epochs = 10

    train_ds, test_ds = load_data()
    train_dl, test_dl = create_dataloaders(train_ds, test_ds)

    # Named `vae` so the instance does not shadow the imported `model` module.
    vae = model.VAE(784, 128, latent_dim)
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

    train_losses = []  # one entry per training iteration
    val_losses = []  # one entry per epoch: mean loss over the test set

    for e in range(num_epochs):
        vae.train()
        for i, (batch, _) in enumerate(train_dl):
            recon_batch = vae(batch)
            # The forward pass leaves the posterior parameters on the module
            # as .mu / .log_var, which are read back here.
            loss = loss_function(batch, recon_batch, vae.mu, vae.log_var)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

            if i % 100 == 0:
                print("Training epoch {}, iteration {} of {} ({} %), loss={:0.5f}".format(e, i, len(train_dl), 100 * i // len(train_dl), train_losses[-1]))

        # Held-out evaluation: same objective, no gradient tracking.
        vae.eval()
        with torch.no_grad():
            batch_losses = [
                loss_function(test_batch, vae(test_batch), vae.mu, vae.log_var).item()
                for test_batch, _ in test_dl
            ]
        if batch_losses:
            val_losses.append(sum(batch_losses) / len(batch_losses))
            print(f"Epoch {e}: mean test loss={val_losses[-1]:0.5f}")

    # Sample in eval mode and without autograd: generation needs no gradients.
    vae.eval()
    with torch.no_grad():
        prior_sampling(vae, z_dim=latent_dim)

    print("DONE")