# CodeExamples / Face Recognition / visualisation.py
# Helpers for plotting image tensors side by side with matplotlib.
from torchvision.utils import make_grid
from torchvision.io import read_image
from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt

import torchvision.transforms.functional as F


# Crop saved figures to their content instead of keeping the default margins.
plt.rcParams.update({"savefig.bbox": 'tight'})

def show(imgs, layer):
    """Plot images side by side in a single row of matplotlib subplots.

    Args:
        imgs: A list of image tensors, or a single tensor. A 2-D tensor is
            drawn as one panel; any other tensor is split along its first
            dimension and each slice gets its own panel (for a 3-D tensor,
            presumably one panel per channel).
        layer: Accepted for caller compatibility; not used.

    Returns:
        The ``(fig, axs)`` pair from ``plt.subplots``; ``axs`` is indexed as
        ``axs[0, col]`` (one row of panels), so callers can add titles or
        save the figure.

    Raises:
        ValueError: If there is nothing to plot.
    """
    if isinstance(imgs, torch.Tensor) and imgs.dim() == 2:
        # Iterating an H x W tensor would yield 1-D rows, which
        # to_pil_image cannot convert, so keep it as a single panel.
        panels = [imgs]
    else:
        panels = list(imgs)
    if not panels:
        raise ValueError("show() needs at least one image to plot")

    fig, axs = plt.subplots(
        ncols=len(panels), squeeze=False, figsize=(20, 4.8)
    )
    for col, img in enumerate(panels):
        pil_img = F.to_pil_image(img.detach())
        # Fixed 0-255 display range so single-channel panels are shown on an
        # absolute scale rather than normalised per panel.
        ax = axs[0, col]
        ax.imshow(np.asarray(pil_img), cmap="gray", vmin=0, vmax=255)
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    return fig, axs

def display_images(images, i):
    """Tile ``images`` into a single row with ``make_grid`` and plot it.

    Args:
        images: Anything ``make_grid`` accepts (a list of image tensors or a
            batched tensor); ``nrow`` is set to ``len(images)`` so every
            image lands in the same row.
        i: Forwarded unchanged to ``show`` as its ``layer`` argument.

    NOTE(review): ``show`` splits a multi-dimensional tensor along its first
    dimension, so the grid may be drawn one panel per channel rather than as
    a single picture — confirm this is intended.
    """
    images_per_row = len(images)
    show(make_grid(images, nrow=images_per_row), i)

def show_all_channels(output, i):
    """Plot every slice of ``output`` taken along its first dimension.

    Args:
        output: A tensor whose leading-dimension slices are displayed
            (presumably a single C x H x W feature map — TODO confirm).
        i: Forwarded unchanged to ``display_images``.
    """
    display_images([*output], i)

#dog1_int = read_image(str(Path('assets') / 'dog1.jpg'))
#dog2_int = read_image(str(Path('assets') / 'dog2.jpg'))
#print("made images")
#images = [read_image("datasets/DigiFace/subjects_0-1999_72_imgs/0/" + str(i) + ".png") for i in range(0,10)]
#print(len(images))
#display_images(images, 0)
#print("displayed")