from model_attn import VQVAE as AttnVQVAE
from model_soft import VQVAE as SoftVQVAE
import utils
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device = {device}")

    embedding_dim = 64
    num_embeddings = 512
    commitment_cost = 0.25
    learning_rate = 1e-3

    train_ds, test_ds = utils.load_data()
    train_dl, test_dl = utils.create_dataloaders(train_ds, test_ds)

    # Load both pretrained models; map_location keeps this working on CPU-only hosts.
    model_attn = AttnVQVAE(3, num_embeddings, embedding_dim).to(device)
    model_attn.load_state_dict(torch.load('model_attn_noise.pth', map_location=device))
    model_soft = SoftVQVAE(3, num_embeddings, embedding_dim).to(device)
    model_soft.load_state_dict(torch.load('model_soft_noise.pth', map_location=device))

    ## Training ##
    # optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # num_epochs = 2
    # train_losses = []
    # for e in range(num_epochs):
    #     model.train()
    #     for i, (batch, _) in enumerate(train_dl):
    #         batch = batch.to(device)
    #         vq_loss, pred = model(batch)
    #         recon_loss = F.mse_loss(pred, batch)
    #         loss = vq_loss + recon_loss
    #         optimizer.zero_grad()
    #         loss.backward()
    #         optimizer.step()
    #         train_losses.append(loss.item())
    #         if i % 100 == 0:
    #             print("Training epoch {}, iteration {} of {} ({} %), loss={:0.5f}".format(
    #                 e, i, len(train_dl), 100 * i // len(train_dl), train_losses[-1]))
    # torch.save(model.state_dict(), 'model_attn_noise.pth')

    ## Testing ##
    model_attn.eval()
    model_soft.eval()
    # batch, _ = next(iter(test_dl))

    # None is the no-noise baseline; the rest are noise scales passed to forward().
    noise_arr = [None, 0.01, 0.05, 0.1, 0.3, 0.5, 1]

    def avg_recon_mse(model, noise):
        """Average reconstruction MSE over the test set at one noise level."""
        losses = []
        for batch, _ in test_dl:
            batch = batch.to(device)
            with torch.no_grad():
                _, pred = model(batch, noise)
            losses.append(F.mse_loss(pred, batch).item())
        return np.mean(losses)

    recon_loss_attn = [avg_recon_mse(model_attn, noise) for noise in noise_arr]
    recon_loss_soft = [avg_recon_mse(model_soft, noise) for noise in noise_arr]

    # Plot the no-noise baseline (None) at x = 0; matplotlib would otherwise
    # silently drop a None x-value.
    noise_levels = [0.0 if n is None else n for n in noise_arr]
    plt.plot(noise_levels, recon_loss_attn, label="Validation loss attentive", marker="o")
    plt.plot(noise_levels, recon_loss_soft, label="Validation loss soft", marker="o")
    plt.title("Reconstruction Error")
    plt.xlabel("Noise")
    plt.ylabel("Average MSE")
    plt.legend()
    plt.savefig("val_loss.png")
    plt.close()

    # fig, axes = plt.subplots(2, len(noise_arr) + 1, figsize=(15, 6))  # 2 rows, N+1 columns
    # for i, images in enumerate(zip(output_images_attn, output_images_soft)):
    #     for j, image in enumerate(images):
    #         img_grid = make_grid(image[0].cpu(), nrow=1)  # Create a grid for the batch
    #         img_grid = np.transpose(img_grid.numpy(), (1, 2, 0))
    #         axes[j, i].imshow(img_grid, interpolation='nearest')
    #         if i == 0:
    #             axes[j, i].set_title("Model Attention Input" if j == 0 else "Model Soft Input")
    #         else:
    #             axes[j, i].set_title(f"Noise: {noise_arr[i-1]}")
    #         axes[j, i].axis('off')
    # plt.tight_layout()
    # plt.savefig("image_noise_comparison.png")
    # plt.close()
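    ## Sketch (assumption): the commented comparison figure above reads
    ## output_images_attn / output_images_soft, which this script never builds.
    ## One way to construct them, reusing the same forward(batch, noise) call as the
    ## evaluation loop; uncomment together with the figure code above.
    # batch, _ = next(iter(test_dl))
    # output_images_attn = [batch]  # column 0 of each row is the input image
    # output_images_soft = [batch]
    # for noise in noise_arr:
    #     with torch.no_grad():
    #         _, rec_attn = model_attn(batch.to(device), noise)
    #         _, rec_soft = model_soft(batch.to(device), noise)
    #     output_images_attn.append(rec_attn.cpu())
    #     output_images_soft.append(rec_soft.cpu())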
    ## For plotting images ##
    # model.eval()
    # batch, _ = next(iter(test_dl))
    # output_images = [batch]
    # noise_arr = [0.01, 0.05, 0.1, 0.15, 0.5, 1]
    # for noise in noise_arr:
    #     with torch.no_grad():
    #         _, reconstructed_ = model(batch.to(device), noise)
    #     output_images.append(reconstructed_)
    # fig, axes = plt.subplots(1, len(output_images), figsize=(15, 2))  # 1 row, N+1 columns
    # for j, image in enumerate(output_images):
    #     img_grid = make_grid(image[0].cpu(), nrow=1)  # Create a grid for the batch
    #     img_grid = np.transpose(img_grid.numpy(), (1, 2, 0))
    #     axes[j].imshow(img_grid, interpolation='nearest')
    #     if j == 0:
    #         axes[j].set_title("Input Image")
    #     else:
    #         axes[j].set_title(f"Noise: {noise_arr[j-1]}")
    #     axes[j].axis('off')
    # plt.tight_layout()
    # plt.savefig("image_attn_noise.png")
    # plt.close()
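    ## Sketch (assumption): the `noise` argument to forward() is assumed to scale
    ## Gaussian noise injected into the latent code before decoding. The toy module
    ## below only illustrates that pattern; it is not the VQVAE from model_attn or
    ## model_soft, and `LatentNoiseToy` is a hypothetical name.
    # import torch.nn as nn
    #
    # class LatentNoiseToy(nn.Module):
    #     def __init__(self):
    #         super().__init__()
    #         self.enc = nn.Conv2d(3, 8, 3, padding=1)
    #         self.dec = nn.Conv2d(8, 3, 3, padding=1)
    #
    #     def forward(self, x, noise=None):
    #         z = self.enc(x)
    #         if noise is not None:
    #             z = z + noise * torch.randn_like(z)  # corrupt the latents at eval time
    #         return self.dec(z)
    #
    # toy = LatentNoiseToy().eval()
    # x = torch.rand(1, 3, 32, 32)
    # with torch.no_grad():
    #     print(F.mse_loss(toy(x, None), x).item())  # reconstruction error, clean latents
    #     print(F.mse_loss(toy(x, 1.0), x).item())   # error with heavily corrupted latents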