"""Helpers for visualising image tensors and per-channel feature maps.

A tensor handed to ``show`` is iterated along its first dimension, so a
``(C, H, W)`` tensor is drawn as ``C`` separate grayscale panels laid out in a
single row of subplots.
"""
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as F
from torchvision.io import read_image
from torchvision.utils import make_grid

plt.rcParams["savefig.bbox"] = 'tight'


def show(imgs, layer):
    """Plot every element of *imgs* as a grayscale panel in one row.

    Args:
        imgs: An iterable of image tensors. It is always materialised with
            ``list(imgs)``; a single tensor is therefore iterated along its
            first dimension, so a ``(C, H, W)`` tensor yields ``C`` panels.
        layer: Caller-supplied label. It is currently not used for plotting
            and is kept only for interface compatibility.

    Returns:
        The matplotlib ``Figure`` holding the panels.

    Raises:
        ValueError: If *imgs* contains no elements.
    """
    # list() (not an isinstance check) is what splits a single tensor into
    # its channels; this is relied on by display_images/show_all_channels.
    panels = list(imgs)
    if not panels:
        raise ValueError("show() needs at least one image to plot")

    fig, axs = plt.subplots(ncols=len(panels), squeeze=False, figsize=(20, 4.8))
    for col, img in enumerate(panels):
        pil_img = F.to_pil_image(img.detach())
        ax = axs[0, col]
        # Fixed 0..255 range so panels share one gray scale instead of each
        # being auto-stretched independently.
        ax.imshow(np.asarray(pil_img), cmap='gray', vmin=0, vmax=255)
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    return fig


def display_images(images, i, normalize=False):
    """Tile *images* with ``make_grid`` and plot the result via ``show``.

    Args:
        images: Anything ``make_grid`` accepts (a 4-D tensor or a list of
            tensors). ``nrow=len(images)`` keeps all of them on one row.
        i: Label forwarded to ``show`` as its ``layer`` argument.
        normalize: Forwarded to ``make_grid``. ``to_pil_image`` multiplies
            float tensors by 255 before casting to uint8, so float inputs
            outside ``[0, 1]`` (e.g. raw activations) need ``normalize=True``
            to display correctly. Leave ``False`` for uint8 images.

    Returns:
        The matplotlib ``Figure`` produced by ``show``.
    """
    grid = make_grid(images, nrow=len(images), normalize=normalize)
    # grid is one (C, H', W') tensor; show() draws each of its channels.
    return show(grid, i)


def show_all_channels(output, i, normalize=False):
    """Plot each entry along the first dimension of *output* as a panel.

    Args:
        output: A tensor whose first dimension is iterated (presumably the
            channel dimension of a single feature map — confirm at callers).
        i: Label forwarded down to ``show``.
        normalize: See ``display_images``.

    Returns:
        The matplotlib ``Figure`` produced by ``show``.
    """
    return display_images(list(output), i, normalize=normalize)


# Example usage (requires the DigiFace dataset on disk):
#
#     images = [
#         read_image(str(Path("datasets/DigiFace/subjects_0-1999_72_imgs/0") / f"{k}.png"))
#         for k in range(10)
#     ]
#     display_images(images, 0)
#     plt.show()