# JSCC / VQ-VAE: main.py
import model_soft
import model_attn
import utils
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np

if __name__ == "__main__": 
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device = {device}")
    embedding_dim = 64
    num_embeddings = 512
    commitment_cost = 0.25
    learning_rate = 1e-3

    train_ds, test_ds = utils.load_data()
    train_dl, test_dl = utils.create_dataloaders(train_ds, test_ds)

    # NB: the module names model_attn / model_soft are rebound to model
    # instances below; the modules are not needed again after construction.
    model_attn = model_attn.VQVAE(3, num_embeddings, embedding_dim).to(device)
    model_attn.load_state_dict(torch.load('model_attn_noise.pth', map_location=device))

    model_soft = model_soft.VQVAE(3, num_embeddings, embedding_dim).to(device)
    model_soft.load_state_dict(torch.load('model_soft_noise.pth', map_location=device))
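
    # Both checkpoints are assumed to share the same forward interface,
    # i.e. vq_loss, reconstruction = model(batch, noise), with noise=None
    # meaning a clean (noiseless) channel. Minimal sanity-check sketch under
    # that assumption (uncomment to run):
    # with torch.no_grad():
    #     sample, _ = next(iter(test_dl))
    #     _, recon = model_attn(sample.to(device), None)
    #     assert recon.shape == sample.shape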
        
    ## Training ##
    # (template: bind `model` to model_attn or model_soft before uncommenting)
    # optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # num_epochs = 2
    # train_losses = []

    # for e in range(num_epochs):
    #     model.train()
    #     for i, (batch, _) in enumerate(train_dl): 
    #         batch = batch.to(device)
    #         vq_loss, pred = model(batch)
    #         recon_loss = F.mse_loss(pred, batch) 
    #         loss = vq_loss + recon_loss
    #         optimizer.zero_grad()
    #         loss.backward()
    #         optimizer.step()
    #         train_losses.append(loss.item())

    #         if i%100==0:
    #             print("Training epoch {}, iteration {} of {} ({} %), loss={:0.5f}".format(e, i, len(train_dl), 100*i//len(train_dl), train_losses[-1]))

    # torch.save(model.state_dict(), 'model_attn_noise.pth')

    ## Testing ##
    model_attn.eval()
    model_soft.eval()
    # batch, _ = next(iter(test_dl))
    noise_arr = [None, 0.01, 0.05, 0.1, 0.3, 0.5, 1]  # None = no added noise (baseline)

    # Sweep the attentive model over all noise levels
    recon_loss_attn = []
    for noise in noise_arr:
        batch_losses = []
        for batch, _ in test_dl:
            with torch.no_grad():
                batch = batch.to(device)
                _, pred = model_attn(batch, noise)
                batch_losses.append(F.mse_loss(pred, batch).item())
        recon_loss_attn.append(np.mean(batch_losses))

    # Same sweep for the soft-quantization model
    recon_loss_soft = []
    for noise in noise_arr:
        batch_losses = []
        for batch, _ in test_dl:
            with torch.no_grad():
                batch = batch.to(device)
                _, pred = model_soft(batch, noise)
                batch_losses.append(F.mse_loss(pred, batch).item())
        recon_loss_soft.append(np.mean(batch_losses))
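
    # Design note: the two sweeps above are kept explicit for readability; a
    # shared helper (sketch, assuming the same forward signature as above)
    # could replace them:
    # def sweep(model):
    #     losses = []
    #     for noise in noise_arr:
    #         per_batch = []
    #         for batch, _ in test_dl:
    #             with torch.no_grad():
    #                 batch = batch.to(device)
    #                 _, pred = model(batch, noise)
    #                 per_batch.append(F.mse_loss(pred, batch).item())
    #         losses.append(np.mean(per_batch))
    #     return losses
    # recon_loss_attn = sweep(model_attn)
    # recon_loss_soft = sweep(model_soft)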
      

    # Plot the clean-channel baseline (None) at noise level 0 so it is not
    # dropped as missing data by matplotlib
    noise_x = [0 if n is None else n for n in noise_arr]
    plt.plot(noise_x, recon_loss_attn, label="Validation loss attentive", marker="o")
    plt.plot(noise_x, recon_loss_soft, label="Validation loss soft", marker="o")
    plt.title("Reconstruction Error")
    plt.xlabel("Noise")
    plt.ylabel("Average MSE")
    plt.legend()
    plt.savefig("val_loss.png")
    plt.close()
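
    # Optional: print the raw numbers behind the plot (sketch; None denotes
    # the clean-channel baseline).
    # for noise, l_attn, l_soft in zip(noise_arr, recon_loss_attn, recon_loss_soft):
    #     print(f"noise={noise}: attn MSE={l_attn:.5f}, soft MSE={l_soft:.5f}")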
    
    # (Requires output_images_attn / output_images_soft lists collected during the sweeps above.)
    # fig, axes = plt.subplots(2, len(noise_arr) + 1, figsize=(15, 6))  # 2 rows, N+1 columns

    # for i, images in enumerate(zip(output_images_attn, output_images_soft)):
    #     for j, image in enumerate(images):
    #         img_grid = make_grid(image[0].cpu(), nrow=1)  # Create a grid for the batch
    #         img_grid = np.transpose(img_grid.numpy(), (1, 2, 0))
            
    #         axes[j, i].imshow(img_grid, interpolation='nearest')
    #         if i == 0:
    #             axes[j, i].set_title("Model Attention Input" if j == 0 else "Model Soft Input")
    #         else:
    #             axes[j, i].set_title(f"Noise: {noise_arr[i-1]}")

    #         axes[j, i].axis('off')

    # plt.tight_layout()
    # plt.savefig("image_noise_comparison.png")
    # plt.close()


    ## For plotting images ##
    # (template: bind `model` to model_attn or model_soft before uncommenting)
    # model.eval()
    # batch, _ = next(iter(test_dl))
    # output_images = [batch]
    # noise_arr = [0.01, 0.05, 0.1, 0.15, 0.5, 1]
    # for noise in noise_arr:
    #     with torch.no_grad():
    #         _, reconstructed_ = model(batch.to(device), noise)
    #         output_images.append(reconstructed_)

    # fig, axes = plt.subplots(1, len(output_images), figsize=(15, 2))  # 1 row, N+1 columns

    # for j, image in enumerate(output_images):
    #     img_grid = make_grid(image[0].cpu(), nrow=1)  # Create a grid for the batch
    #     img_grid = np.transpose(img_grid.numpy(), (1, 2, 0))
        
    #     axes[j].imshow(img_grid, interpolation='nearest')
    #     if j == 0:
    #         axes[j].set_title("Input Image")
    #     else:
    #         axes[j].set_title(f"Noise: {noise_arr[j-1]}")
    #     axes[j].axis('off')

    # plt.tight_layout()
    # plt.savefig("image_attn_noise.png")
    # plt.close()