borderownership / src / rf_mapping / image.py
image.py
Raw
"""
Code for image processing, etc.

Tony Fu, Jun 25, 2022
"""
import os
import sys
import math

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import torchvision.transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

sys.path.append('../..')
import src.rf_mapping.constants as c


#######################################.#######################################
#                                                                             #
#                                    CLIP                                     #
#                                                                             #
###############################################################################
def clip(x, x_min, x_max):
    """
    Clamps x into the closed interval [x_min, x_max].

    The upper bound is applied first and the lower bound second, so if the
    bounds are given in the wrong order (x_min > x_max), x_min is returned.
    """
    return max(x_min, min(x_max, x))


#######################################.#######################################
#                                                                             #
#                               NORMALIZE_IMG                                 #
#                                                                             #
###############################################################################
def normalize_img(img):
    """
    Standardizes pixel values so that they are roughly Norm(0, 1).

    The mean is always subtracted. The result is then divided by its standard
    deviation, unless that standard deviation is zero, in which case the
    mean-subtracted image is returned as is.
    """
    centered = img - img.mean()
    spread = centered.std()
    if math.isclose(spread, 0):
        # Zero spread: skip the division to avoid dividing by zero.
        return centered
    return centered/spread


#######################################.#######################################
#                                                                             #
#                           IMG PREPROCESS FUNCTIONS                          #
#                                                                             #
###############################################################################
def preprocess_img_to_tensor(img, img_size=None):
    """
    Converts an image numpy array into a float tensor with a leading batch
    dimension before presenting it to a Pytorch model.

    Parameters
    ----------
    img : numpy.ndarray
        Either a 2D (grayscale) image, which is copied into three identical
        channels, or an array whose shape contains a 3. If the first axis of
        length 3 is axis 0 (channels-first), the array is transposed to
        channels-last before the tensor conversion.
    img_size : optional
        If not None, it is passed to torchvision's Resize and the tensor is
        resized accordingly.

    Returns
    -------
    img_tensor : torch.Tensor
        A FloatTensor with a batch axis inserted at dim 0, moved to c.DEVICE.

    Raises
    ------
    ValueError
        If a non-2D img has no axis of length 3 (the shape lookup fails).
    """
    # Normalization is currently disabled. The statistics below only apply to
    # the imagenet dataset we have:
    #     mean=[-0.01618503, -0.01468056, -0.01345447]
    #     std=[0.45679083, 0.44849625, 0.44975275]
    to_tensor = T.Compose([T.ToTensor()])

    if len(img.shape) == 2:
        # Grayscale: replicate the single plane into three channels.
        rgb = np.zeros((3, img.shape[0], img.shape[1]))
        rgb[:] = img
        img = rgb

    channels_first = (img.shape.index(3) == 0)
    if channels_first:
        img = np.transpose(img, (1, 2, 0))

    img_tensor = to_tensor(img).type('torch.FloatTensor')
    img_tensor = img_tensor.unsqueeze(0)  # add the batch dimension

    if img_size is not None:
        img_tensor = T.Resize(img_size)(img_tensor)

    return img_tensor.to(c.DEVICE)


def preprocess_img_for_plot(img, norm=True):
    """
    Rescales an image to [0, 1] and makes sure it is in (height, width, 3)
    format.

    Parameters
    ----------
    img : array-like with .min(), .max() and .shape
        A 2D image (returned without any axis changes), a 3D image, or a 4D
        array that is squeezed before the channel axis is inspected.
    norm : bool
        If True, shift the minimum to 0 and divide by the resulting maximum.

    Returns
    -------
    img
        The (optionally rescaled) image. If the first axis of length 3 is
        axis 0, the axes are reordered to channels-last.

    Raises
    ------
    ValueError
        If a non-2D img has no axis of length 3 (the shape lookup fails).
    """
    if norm:
        img = img - img.min()
        peak = img.max()
        # A constant image has peak == 0 after the shift; dividing would fill
        # it with NaNs, so it is left as all zeros instead.
        if peak != 0:
            img = img/peak
    if len(img.shape) == 2:
        return img
    if len(img.shape) == 4:
        img = np.squeeze(img)
    if img.shape.index(3) == 0:
        img = np.transpose(img, (1, 2, 0))
    return img


def tensor_to_img(img_tensor):
    """
    Converts img_tensor into a numpy array that can be plotted by plt.imshow().

    Parameters
    ----------
    img_tensor : torch.Tensor
        A (batch, channels, height, width) tensor, or a (3, height, width)
        RGB tensor. It may live on any device and may require grad; the input
        is not modified.

    Returns
    -------
    img : numpy.ndarray
        (height, width, 3) if the tensor has exactly three channels (only the
        first batch item is used). Otherwise the first channel of the first
        batch item, i.e. img_tensor[0, 0, ...].
    """
    # Work on a detached CPU copy: .numpy() cannot be called on tensors that
    # require grad or that live on a non-CPU device.
    img = img_tensor.clone().detach().cpu()

    if len(img.shape) == 4 and img.shape[1] == 3:
        # Batched RGB image: take the first item, move channels last.
        return img[0].permute(1, 2, 0).numpy()
    if len(img.shape) == 3 and img.shape[0] == 3:
        # Unbatched RGB image in channels-first format.
        return img.permute(1, 2, 0).numpy()

    # Plot only the first channel if the channels are not RGB.
    return img[0, 0, ...].numpy()


#######################################.#######################################
#                                                                             #
#                                  MAKE_BOX                                   #
#                                                                             #
###############################################################################
def make_box(box_indices, linewidth=1):
    """
    Builds a red, unfilled matplotlib.patches.Rectangle from box indices given
    in (vx_min, hx_min, vx_max, hx_max) format. Both the min and the max
    indices are treated as inclusive, so the rectangle is
    (vx_max - vx_min + 1) tall and (hx_max - hx_min + 1) wide. Example usage:

        plt.imshow(img)
        ax = plt.gca()
        rect = make_box((0, 0, 100, 50))
        ax.add_patch(rect)

    This script plots a red rectangle box with height 101 and width 51,
    anchored at (x, y) = (0, 0), i.e., the top-left corner of the img.
    """
    vx_min, hx_min, vx_max, hx_max = box_indices
    # Matplotlib anchors the rectangle at (x, y), i.e. (horizontal, vertical).
    return patches.Rectangle(xy=(hx_min, vx_min),
                             width=hx_max - hx_min + 1,
                             height=vx_max - vx_min + 1,
                             linewidth=linewidth,
                             edgecolor='r',
                             facecolor='none')


#######################################.#######################################
#                                                                             #
#                                 IMG DATASET                                 #
#                                                                             #
###############################################################################
class ImgDataset(Dataset):
    """
    A Dataset of images stored as .npy files in a single directory. Each item
    is the torch.tensor produced by preprocess_img_to_tensor, paired with a
    dummy label of zero because this project does not care about the labels.
    Example usage:

        num_images = 100
        img_dir = Path(__file__).parent.parent.parent.joinpath('data/imagenet')
        img_names = [f"{i}.npy" for i in range(num_images)]
        imagenet_data = ImgDataset(img_dir, img_names)
        for img, label in imagenet_data:
            outputs = model(img)

    This fetches the first 100 images from the imagenet folder and presents
    them to the model.

    NOTE(review): no normalization is applied here; the stored images are
    presumably already normalized to be roughly between [-1, 1] -- confirm.
    """
    def __init__(self, img_dir, img_names):
        """
        Constructs an ImgDataset object.

        Parameters
        ----------
        img_dir : str or path-like
            The directory of the images.
        img_names : list of strs or path-likes
            The names of the image files.
        """
        self.transform = preprocess_img_to_tensor
        self.img_names = img_names
        self.img_dir = img_dir

    def __len__(self):
        """Returns the number of image file names in the dataset."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Loads image no.<idx> from disk and returns (img_tensor, 0)."""
        file_name = self.img_names[idx]
        raw_img = np.load(os.path.join(self.img_dir, file_name))
        # Don't care about labels, so every image gets label 0.
        return self.transform(raw_img), 0


if __name__ == "__main__":
    # Testing ImgDataset: load the first <num_images> images through a
    # DataLoader and plot image no.<image_idx> of the first batch.
    num_images = 100
    image_idx = 0

    img_names = [f"{i}.npy" for i in range(num_images)]
    img_dir = c.REPO_DIR + '/data/imagenet'
    imagenet_data = ImgDataset(img_dir, img_names)
    imagenet_dataloader = DataLoader(imagenet_data,
                                     batch_size=64,
                                     shuffle=False)

    # Only the first batch is needed; the labels are ignored.
    batch_iter = iter(imagenet_dataloader)
    imgs, _ = next(batch_iter)

    img = preprocess_img_for_plot(imgs[image_idx])
    plt.imshow(img)
    plt.show()


#######################################.#######################################
#                                                                             #
#                              ONE_SIDED_ZERO_PAD                             #
#                                                                             #
###############################################################################
def one_sided_zero_pad(patch, desired_size, box):
    """
    Zero-pads a 2D patch up to desired_size. The zeros are placed on the
    side(s) where the patch touches the top and/or left edge of the image;
    otherwise they are placed at the bottom and right.

    Parameters
    ----------
    patch : numpy.ndarray
        A 2D array no larger than desired_size in either dimension.
    desired_size : (int, int)
        The (height, width) of the padded output.
    box : (int, int, int, int)
        The spatial location of the patch on the image, in
        (vx_min, hx_min, vx_max, hx_max) format. Only vx_min and hx_min are
        used: vx_min == 0 means the patch touches the top edge, and
        hx_min == 0 means it touches the left edge.

    Returns
    -------
    padded_patch : numpy.ndarray
        The original patch object if it already has the desired size.
        Otherwise a new float array of shape desired_size containing the
        patch, with zeros everywhere else.

    Raises
    ------
    ValueError
        If the patch is not 2D or is larger than desired_size.
    """
    if patch.shape == tuple(desired_size):
        return patch

    desired_h, desired_w = desired_size[0], desired_size[1]
    patch_h, patch_w = patch.shape
    if patch_h > desired_h or patch_w > desired_w:
        raise ValueError(f"patch of shape {patch.shape} does not fit into "
                         f"the desired size ({desired_h}, {desired_w}).")

    vx_min, hx_min, _, _ = box
    touching_top_edge = (vx_min == 0)
    touching_left_edge = (hx_min == 0)

    # A patch touching the top (left) edge is pushed to the bottom (right) of
    # the output, so the zeros end up on the top (left). Explicit offsets are
    # used instead of negative slices so that a zero-sized patch dimension
    # does not turn into a full-range slice ("-0:").
    row_start = desired_h - patch_h if touching_top_edge else 0
    col_start = desired_w - patch_w if touching_left_edge else 0

    padded_patch = np.zeros((desired_h, desired_w))
    padded_patch[row_start:row_start + patch_h,
                 col_start:col_start + patch_w] = patch

    return padded_patch


if __name__ == "__main__":
    # Testing one_sided_zero_pad: pad a patch of ones that touches both the
    # top and the left edge of the image, then plot the padded result.
    patch_h, patch_w = (20, 15)
    patch = np.ones((patch_h, patch_w))
    desired_size = (50, 40)

    vx_min = 0
    hx_min = 0
    # Only the min corner of the box is used by one_sided_zero_pad.
    box = (vx_min, hx_min, vx_min + patch_h, hx_min + patch_w)
    padded_patch = one_sided_zero_pad(patch, desired_size, box)
    print(box)
    plt.imshow(padded_patch)
    # Without show(), the figure is never displayed when run as a script.
    plt.show()