borderownership / src / rf_mapping / guided_backprop.py
guided_backprop.py
Raw
import os
import sys
import copy
import warnings

import numpy as np
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
import matplotlib.pyplot as plt

sys.path.append('../..')
from src.rf_mapping.hook import SizeInspector
from src.rf_mapping.spatial import get_rf_sizes, SpatialIndexConverter
import src.rf_mapping.constants as c
from src.rf_mapping.net import get_truncated_model
from src.rf_mapping.image import (preprocess_img_to_tensor,
                                  preprocess_img_for_plot,
                                  make_box,)


#######################################.#######################################
#                                                                             #
#                               GUIDED_BACKPROP                               #
#                                                                             #
###############################################################################
class GuidedBackprop:
    """
    Generates the guided backpropagation gradient maps for a spatial location
    of a particular unit of a particular layer.

    This class was implemented by utkuozbulak (https://github.com/utkuozbulak/
    pytorch-cnn-visualizations) and later modified by Dr. Taekjun Kim to
    include unit specificity and by me to improve generalizability for other
    model architectures.

    Parameters
    ----------
    model : torch.nn.Module
        The full model. It is truncated with ``get_truncated_model`` so that
        its output is the output of the target layer.
    layer_index : int
        Index of the target layer, as understood by ``get_truncated_model``.
    remove_neg_grad : bool, default True
        If True, each ReLU gradient is masked by the forward activation AND by
        the sign of the gradient (guided backprop, Springenberg et al. 2015,
        Figure 1). If False, only the forward-activation mask is applied,
        i.e. ordinary backprop through the ReLUs.

    Notes
    -----
    Hooks use the legacy ``register_backward_hook`` API. NOTE(review): the
    newer ``register_full_backward_hook`` may not work with in-place ReLUs
    (e.g. AlexNet's); verify before switching.
    """
    def __init__(self, model, layer_index, remove_neg_grad=True):
        self.model = get_truncated_model(model, layer_index)

        # Gradient w.r.t. the input of the first leaf layer (set by
        # _first_hook_function during backward).
        self.gradients = None
        # Stack of ReLU forward outputs, consumed last-in-first-out by the
        # ReLU backward hook.
        self.forward_relu_outputs = []
        self.remove_neg_grad = remove_neg_grad

        self._register_hook_to_first_layer(self.model)
        self._update_relus(self.model)

    def _first_hook_function(self, module, grad_in, grad_out):
        """
        Captures the gradient w.r.t. the input of the first (leaf) layer.
        """
        # For a Conv2d, grad_in typically holds the gradients w.r.t.
        # (input, weight, bias); [0] is the one w.r.t. the input image.
        self.gradients = grad_in[0]
        if not isinstance(module, nn.Conv2d):
            warnings.warn("The first layer is not Conv2d.")

    def _forward_hook_function(self, module, ten_in, ten_out):
        """
        Stores the output of each ReLU forward pass.
        """
        if isinstance(module, nn.ReLU):
            self.forward_relu_outputs.append(ten_out)

    def _rectify(self, forward_output, target_grad):
        """
        Applies the (guided) ReLU backward rule to one gradient.

        Keeps the gradient only where the ReLU was active in the forward pass
        and, if ``remove_neg_grad`` is set, only where the gradient is positive
        (see Springenberg et al. 2015, Figure 1). Neither input is modified.
        Raises RuntimeError if the two tensors cannot be broadcast together.
        """
        modified = (forward_output > 0).to(target_grad.dtype) * target_grad
        if self.remove_neg_grad:
            modified = modified * (target_grad > 0)
        return modified

    def _backward_hook_function(self, module, grad_in, grad_out):
        """
        Rectifies the gradient flowing back through a ReLU.

        Returns a 1-tuple replacing ``grad_in`` for ReLU modules, and None
        (no change) for any other module.
        """
        if not isinstance(module, nn.ReLU):
            return None

        # A ReLU has a single input, so grad_in is a 1-tuple.
        target_grad = grad_in[0]
        try:
            modified_grad_out = self._rectify(self.forward_relu_outputs[-1],
                                              target_grad)
        except RuntimeError:
            # TODO: This is a temporary fix for resnet18().
            # Why this works: the shape mismatch above happens because, for
            # the shortcut conv layers, the target_grad comes from the ReLU
            # between the residual blocks, while the last stored forward
            # output comes from the ReLU inside the residual block that runs
            # parallel to the shortcut. The network calculates the residual
            # path first, then the shortcut path, so even though the backprop
            # did not require the residual path, its ReLU output was appended
            # to self.forward_relu_outputs. Discard it and use the previous one.
            del self.forward_relu_outputs[-1]
            modified_grad_out = self._rectify(self.forward_relu_outputs[-1],
                                              target_grad)

        # This forward output has been consumed.
        del self.forward_relu_outputs[-1]
        return (modified_grad_out,)

    def _register_hook_to_first_layer(self, layer):
        """
        Registers _first_hook_function on the first leaf module of ``layer``.
        """
        # Skip any container layers by always descending into the first child.
        while len(list(layer.children())) != 0:
            layer = list(layer.children())[0]
        layer.register_backward_hook(self._first_hook_function)

    def _update_relus(self, layer):
        """
        Hooks every ReLU (leaf) module so that it now:
            1- stores its output in the forward pass;
            2- rectifies its gradient in the backward pass.
        """
        children = list(layer.children())
        if children:
            # Container layer: recurse.
            for sublayer in children:
                self._update_relus(sublayer)
        elif isinstance(layer, nn.ReLU):
            # The hook functions only act on ReLUs, so only ReLUs are hooked.
            layer.register_forward_hook(self._forward_hook_function)
            layer.register_backward_hook(self._backward_hook_function)

    def generate_gradients(self, img, target_unit, target_spatial_idx):
        """
        Generates the gradient map of the target with respect to the image.

        Parameters
        ----------
        img : numpy.array
            The input image (converted with ``preprocess_img_to_tensor``).
        target_unit : int
            Channel index of the unit in the target layer's output.
        target_spatial_idx : int or (int, int)
            The spatial index (row, col) of the target location on the output
            feature map. If only one scalar is provided, this function
            unravels it into a 2D index.

        Returns
        -------
        gradient_img : numpy.array
            The gradient map w.r.t. the input image, with the batch dimension
            removed.

        Raises
        ------
        ValueError
            If no gradient reached the first layer during this call.
        """
        # Reset per-call state so leftovers from a previous (possibly failed)
        # call cannot leak into this one and the None-check below is valid.
        self.gradients = None
        self.forward_relu_outputs = []
        self.model.zero_grad()

        # Forward pass. The input must require grad so that the first layer's
        # backward hook receives a gradient w.r.t. the image.
        x = preprocess_img_to_tensor(img).clone().detach().requires_grad_(True)
        output = self.model(x)

        # We only care about the gradient w.r.t. the target.
        if not isinstance(target_spatial_idx, (tuple, list)):
            target_spatial_idx = np.unravel_index(
                target_spatial_idx, (output.shape[2], output.shape[3]))
        row, col = target_spatial_idx[0], target_spatial_idx[1]

        # Seed gradient: zero everywhere except the target location, where it
        # equals the target's activation. zeros_like matches the output's
        # dtype and device.
        x_target_only = torch.zeros_like(output)
        x_target_only[0, target_unit, row, col] = \
            output[0, target_unit, row, col].detach()
        output.backward(gradient=x_target_only)

        if self.gradients is None:
            raise ValueError("Target layer must be Conv2d.")

        # [0] drops the batch dimension: (1, 3, H, W) -> (3, H, W).
        gradients_img = self.gradients.detach().cpu().numpy()[0]
        return gradients_img


if __name__ == "__main__":
    # Demo: show an ImageNet sample next to the guided-backprop map of a
    # single unit at a fixed spatial location of one AlexNet conv layer.
    model = models.alexnet(pretrained=True).to(c.DEVICE)
    # model = models.resnet18(pretrained=True).to(c.DEVICE)

    # Not used below; constructed anyway in case it has side effects on
    # `model`. NOTE(review): confirm and drop if it does not.
    inspector = SizeInspector(model, (227, 227))
    # inspector.print_summary()

    # Kept as a module-level name; other code in this file may rely on it.
    img_dir = c.REPO_DIR + '/data/imagenet'
    img_idx = 5
    img_path = os.path.join(img_dir, f"{img_idx}.npy")
    img = preprocess_img_for_plot(np.load(img_path))

    layer_idx = 3  # AlexNet = [0, 3, 6, 8, 10] for conv1-5
    gbp = GuidedBackprop(model, layer_idx, remove_neg_grad=False)
    # Unit (channel) 2 at (row 5, col 5) of the layer's output map.
    gbp_map = gbp.generate_gradients(img, 2, (5, 5))

    plt.figure(figsize=(10, 15))
    for position, make_panel in enumerate(
            (lambda: img, lambda: preprocess_img_for_plot(gbp_map)), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(make_panel())
    plt.show()


#######################################.#######################################
#                                                                             #
#                             TEST_GUIDED_BACKPROP                            #
#                                                                             #
###############################################################################
def _test_guided_backprop():
    """
    Visual sanity check for GuidedBackprop.

    For every Conv2d layer of AlexNet, computes the guided-backprop map of one
    unit at the (approximate) center of the layer's output feature map and
    plots (left) the min-max normalized map and (right) a binarized map
    (non-zero = white), both overlaid with the unit's receptive-field box.
    See the "Test observations" notes at the end of this file.
    """
    # model = models.resnet18(pretrained=True)
    # NOTE(review): the __main__ demo moves the model to c.DEVICE but this
    # test does not; confirm the helpers below work on that device before
    # adding `.to(c.DEVICE)` here.
    model = models.alexnet(pretrained=True)
    unit_idx = 1
    image_size = (227, 227)
    layer_indices, rf_sizes = get_rf_sizes(model, image_size, nn.Conv2d)

    output_sizes = SizeInspector(model, image_size).output_sizes
    converter = SpatialIndexConverter(model, image_size)

    # Sample image, only used by the commented-out alternative input below.
    # The path is built locally; the original relied on a global `img_dir`
    # that only existed when the script's first __main__ block had run.
    img_idx = 5
    img_path = os.path.join(c.REPO_DIR, 'data', 'imagenet', f"{img_idx}.npy")
    img = np.load(img_path)
    # dummy_img = preprocess_img_for_plot(img)
    dummy_img = np.random.rand(3, 227, 227)
    # dummy_img = np.ones((3, 227, 227)) * 100

    def img_proc(img):
        """Min-max normalize a (C, H, W) map to [0, 1] and move C last."""
        vmax = img.max()
        vmin = img.min()
        value_range = vmax - vmin
        if value_range == 0:
            # Constant map (e.g. all-zero gradients): avoid 0/0 -> NaN.
            img = np.zeros_like(img, dtype=float)
        else:
            img = (img - vmin) / value_range
        return np.transpose(img, (1, 2, 0))

    for conv_i, layer_idx in enumerate(layer_indices):
        output_size = output_sizes[layer_idx][-1]
        rf_size = rf_sizes[conv_i][0]
        # Center unit; for even output sizes this sits slightly left of/above
        # the true center (see observation (3) at the end of this file).
        spatial_idx = ((output_size - 1) // 2, (output_size - 1) // 2)

        gbp = GuidedBackprop(model, layer_idx)
        gbp_map = gbp.generate_gradients(dummy_img, unit_idx, spatial_idx)

        box = converter.convert(spatial_idx, layer_idx, 0, is_forward=False)

        print(gbp_map.shape)
        print(gbp_map.max(), gbp_map.min())
        print(np.sum(gbp_map < 0))  # number of negative gradient entries

        plt.figure(figsize=(10, 5))
        plt.suptitle(f"guided backprop of conv{conv_i+1} (RF = {rf_size}, output_size = {output_size})")

        plt.subplot(1, 2, 1)
        plt.imshow(img_proc(gbp_map))
        plt.title("gradient map (normalized)")
        # A patch can only belong to one axes, so each subplot gets its own.
        plt.gca().add_patch(make_box(box))

        plt.subplot(1, 2, 2)
        # A pixel is "non-zero" if ANY channel is non-zero; a channel mean
        # could cancel out opposite-signed gradients.
        plt.imshow(np.any(gbp_map != 0, axis=0), cmap='gray')
        plt.title("binarized (non-zeros = white)")
        plt.gca().add_patch(make_box(box))

        plt.grid()
        plt.show()


if __name__ == "__main__":
    _test_guided_backprop()

"""
Test observations:

(1) When the input image is an array of ones, i.e., 
                dummy_img = torch.ones((1,3,227,227))
    the binarized gradient maps show that non-zero gradients are concentrated
    at the top-left corner of the RF and fail to fill the entire RF. Instead,
    the area of the non-zero gradient is roughly equal to the RF of an early
    layer. This suggests that gradient calculations may systematically be
    biased toward the top-left. [Need math proof]

(2) When the input image is natural, the non-zero gradients fill the entire
    RF.

(3) When the RF of the unit is close to or larger than the image size, the left
    and upper edges of the RF are often cropped off. This is because many of
    those layers (e.g. conv16-20 of resnet18) have a feature map size that is
    an even number. For example, the center unit of a feature map of size 8
    is at index (8 - 1)//2 = 3, slightly left of the true center (3.5), and
    since the RF is so large in deeper layers, being slightly off-center
    results in a large shift from the actual center on the image.
"""