# borderownership/src/rf_mapping/hook.py
"""
Functions that register hooks for a variety of purposes.

Tony Fu, Jun 22, 2022
"""
import sys
import copy

import numpy as np
import torch
import torch.nn as nn
from torchvision import models

sys.path.append('../..')
from src.rf_mapping.image import preprocess_img_to_tensor
import src.rf_mapping.constants as c


#######################################.#######################################
#                                                                             #
#                             HOOK FUNCTION BASE                              #
#                                                                             #
###############################################################################
class HookFunctionBase:
    """
    A base class that registers a hook function to all specified layer types
    (excluding all container types) in a given model. The child class must
    implement hook_function(). The child class must also call
    self.register_forward_hook_to_layers() by itself.

    The model is deep-copied on construction, so hooks are attached to the
    copy (self.model) and the caller's model is left untouched. The handles of
    all registered hooks are kept in self.hook_handles so that they can be
    detached later with remove_hooks().
    """
    def __init__(self, model, layer_types):
        """
        Constructs a HookFunctionBase object.

        Parameters
        ----------
        model : torchvision.models
            The neural network. A deep copy of it is stored in self.model.
        layer_types : torch.nn.Module subclass or tuple of them
            The layer type(s) you would like to register the forward hook to.
            For example, layer_types = (nn.Conv2d, nn.ReLU) means that all the
            Conv2d and ReLU layers will be registered with the forward hook.
        """
        self.model = copy.deepcopy(model)
        self.layer_types = layer_types
        # Handles returned by register_forward_hook(); needed to remove hooks.
        self.hook_handles = []

    def hook_function(self, module, ten_in, ten_out):
        """Forward-hook callback; must be overridden by the child class."""
        raise NotImplementedError("Child class of HookFunctionBase must "
                                  "implement hook_function(self, module, ten_in, ten_out)")

    def register_forward_hook_to_layers(self, layer):
        """
        Recursively registers self.hook_function as a forward hook on every
        leaf module under (and including) `layer` that is an instance of
        self.layer_types. Container modules are never hooked themselves.
        """
        # If "layer" is a leaf node and matches the layer_type, register hook.
        if len(list(layer.children())) == 0:
            if isinstance(layer, self.layer_types):
                handle = layer.register_forward_hook(self.hook_function)
                self.hook_handles.append(handle)

        # Otherwise (i.e., the layer is a container type layer), recurse.
        else:
            for sublayer in layer.children():
                self.register_forward_hook_to_layers(sublayer)

    def remove_hooks(self):
        """Detaches every hook previously registered by this object."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []


#######################################.#######################################
#                                                                             #
#                             LAYER OUTPUT INSPECTOR                          #
#                                                                             #
###############################################################################
class LayerOutputInspector(HookFunctionBase):
    """
    A class that peeks inside the outputs of all the layers with the
    specified layer type, one image at a time.
    """
    def __init__(self, model, layer_types=(nn.Conv2d,)):
        super().__init__(model, layer_types)
        self.layer_outputs = []
        self.register_forward_hook_to_layers(self.model)

    def hook_function(self, module, ten_in, ten_out):
        # Store a detached copy so later in-place ops cannot alter it.
        self.layer_outputs.append(ten_out.clone().detach())

    def inspect(self, image):
        """
        Given an image, returns the output activation volumes of all the layers
        of the type <layer_type>, for this image only.

        Parameters
        ----------
        image : numpy.array or torch.tensor
            Input image. A numpy array is converted with
            preprocess_img_to_tensor(); a tensor is passed to the model as-is
            (presumably already batched, e.g. [1, 3, 2xx, 2xx] — confirm).

        Returns
        -------
        layer_outputs : list of torch.tensors
            Each item is an output activation volume of a target layer, in
            the order the layers were executed. A new list is returned on
            every call.
        """
        # Start from a fresh list: otherwise outputs of earlier calls would be
        # returned again, and lists handed out earlier would keep growing.
        self.layer_outputs = []
        if isinstance(image, np.ndarray):
            image = preprocess_img_to_tensor(image)
        with torch.no_grad():
            self.model(image)
        return self.layer_outputs


#######################################.#######################################
#                                                                             #
#                              CONV UNIT COUNTER                              #
#                                                                             #
###############################################################################
class ConvUnitCounter(HookFunctionBase):
    """
    A class that counts the number of unique kernels (output channels) of
    each convolutional layer of the given model.

    Layer indices count every hooked leaf module in the order it is executed
    during a forward pass; a leaf module invoked more than once is counted
    once per invocation.
    """
    def __init__(self, model):
        super().__init__(model, nn.Module)
        self._layer_counter = 0
        self.layer_indices = []
        self.num_units = []
        self.register_forward_hook_to_layers(self.model)

    def hook_function(self, module, ten_in, ten_out):
        if isinstance(module, nn.Conv2d):
            self.layer_indices.append(self._layer_counter)
            self.num_units.append(module.out_channels)
        # Every leaf module advances the index, not only Conv2d layers.
        self._layer_counter += 1

    def count(self, image_shape=(227, 227)):
        """
        Runs one forward pass with an all-zeros dummy input and records the
        conv layers encountered.

        Parameters
        ----------
        image_shape : (int, int), default (227, 227)
            Spatial size (height, width) of the 3-channel dummy input.

        Returns
        -------
        layer_indices : [int, ...]
            The indices of nn.Conv2d layers. For torchvision.model.alexnet(),
            this will be [0, 3, 6, 8, 10].
        num_units : [int, ...]
            The count of unique kernels of each convolutional layer.

        layer_indices and num_units have the same length, and their elements
        correspond to each other. Calling count() again gives fresh results
        rather than appending to the previous ones.
        """
        # Reset state so repeated calls don't accumulate duplicate entries.
        self._layer_counter = 0
        self.layer_indices = []
        self.num_units = []

        # Build the dummy input on the model's own device so they cannot
        # disagree; fall back to c.DEVICE for parameter-less models.
        try:
            device = next(self.model.parameters()).device
        except StopIteration:
            device = c.DEVICE
        dummy_input = torch.zeros((1, 3, *image_shape), device=device)

        # Forward pass (no autograd graph is needed to inspect layers).
        with torch.no_grad():
            self.model(dummy_input)

        return self.layer_indices, self.num_units
    

if __name__ == "__main__":
    # Demo: print the Conv2d layer indices and kernel counts of AlexNet.
    # Put the model on c.DEVICE explicitly so its parameters and any input
    # created on c.DEVICE share a device (torchvision builds it on CPU).
    model = models.alexnet().to(c.DEVICE)
    counter = ConvUnitCounter(model)
    layer_indices, num_units = counter.count()
    print(layer_indices)
    print(num_units)


#######################################.#######################################
#                                                                             #
#                              CONV MAX INSPECTOR                             #
#                                                                             #
###############################################################################
class ConvMaxInspector(HookFunctionBase):
    """
    A class that gets the maximum activations and their spatial indices of
    all unique convolutional kernels, one image at a time. Only the first
    item of the batch is inspected.
    """
    def __init__(self, model):
        super().__init__(model, nn.Conv2d)
        self.all_max_activations = []
        self.all_max_indices = []
        self.register_forward_hook_to_layers(self.model)

    def hook_function(self, module, ten_in, ten_out):
        # ten_out[0] is (units, H, W); flatten each unit's map to (units, H*W)
        # so a single reduction yields every unit's max value and position.
        flat_maps = ten_out[0].detach().flatten(start_dim=1)
        max_values, max_positions = flat_maps.max(dim=1)

        self.all_max_activations.append(max_values.tolist())
        # Flat index into the unit's H x W map (row-major: row * W + col).
        self.all_max_indices.append(max_positions.tolist())

    def inspect(self, image):
        """
        Given an image, returns the maximum activation of every unit of every
        Conv2d layer, together with where in the unit's map it occurred.

        Parameters
        ----------
        image : numpy.array or torch.tensor
            Input image. Non-tensor input is converted with
            preprocess_img_to_tensor(); a tensor is passed to the model as-is
            (presumably already batched, e.g. [1, 3, 2xx, 2xx] — confirm).

        Returns
        -------
        max_activations : list of lists of float
            max_activations[layer][unit] is the largest value in that unit's
            output map (first batch item).
        max_indices : list of lists of int
            max_indices[layer][unit] is the flattened (row-major) index of
            that maximum within the unit's H x W output map.
        """
        if not isinstance(image, torch.Tensor):
            image = preprocess_img_to_tensor(image)

        # Start from empty buffers so leftovers from an interrupted earlier
        # forward pass cannot leak into this result.
        self.all_max_activations = []
        self.all_max_indices = []
        with torch.no_grad():
            self.model(image)

        max_activations = self.all_max_activations
        max_indices = self.all_max_indices
        # Hand the lists over to the caller and start fresh for the next call.
        self.all_max_activations = []
        self.all_max_indices = []

        return max_activations, max_indices


def top_bottom_N_image_patches(model, layer_type, image_dir, image_names):
    """
    Work-in-progress: puts `model` in eval mode (this mutates the caller's
    model), runs only the FIRST image listed in `image_names` (a .npy file
    inside `image_dir`) through it, and prints the maximum of every output of
    the layers of type `layer_type`.
    """
    model.eval()
    inspector = LayerOutputInspector(model, layer_type)

    first_image_path = f"{image_dir}/{image_names[0]}"
    first_image = np.load(first_image_path)

    for activation in inspector.inspect(first_image):
        print(activation.max())


if __name__ == '__main__':
    # Demo: print the per-layer maximum Conv2d activation of a pretrained
    # AlexNet for the image <REPO_DIR>/data/imagenet/0.npy.
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favor of `weights=`; confirm against the installed version.
    pretrained_alexnet = models.alexnet(pretrained=True)
    imagenet_dir = c.REPO_DIR + "/data/imagenet"
    top_bottom_N_image_patches(model=pretrained_alexnet,
                               layer_type=nn.Conv2d,
                               image_dir=imagenet_dir,
                               image_names=["0.npy"])


#######################################.#######################################
#                                                                             #
#                               SIZE INSPECTOR                                #
#                                                                             #
###############################################################################
class SizeInspector(HookFunctionBase):
    """
    A class that empirically determines the input and output sizes of all
    layers. This class determines the indexing convention of the layers. The
    indexing follows the flow of data through the model and excludes all
    container-type layers. For example, the indexing of
    torchvision.models.alexnet() is:

          no. | layer name
        ------+-----------
           0  |   Conv1
           1  |   ReLU1
           2  |  MaxPool1
           3  |   Conv2
           4  |   ReLU2
             ...
          19  |   ReLU7
          20  |  Linear3

    A leaf module that runs more than once in a forward pass gets one entry
    per invocation.

    To get the indexing information for any arbitrary model, use the syntax:
        inspector = SizeInspector(model, image_shape)
        inspector.print_summary()
    """
    def __init__(self, model, image_shape):
        """
        Parameters
        ----------
        model : torchvision.models
            The neural network (deep-copied before hooks are attached).
        image_shape : (int, int)
            Spatial size (height, width) of the 3-channel all-zeros dummy
            input used for the single forward pass.
        """
        super().__init__(model, layer_types=(torch.nn.Module))
        self.image_shape = image_shape
        self.layers = []
        self.input_sizes = []
        self.output_sizes = []
        self.register_forward_hook_to_layers(self.model)

        # Build the dummy input on the model's own device so the two cannot
        # disagree; fall back to c.DEVICE for parameter-less models.
        try:
            device = next(self.model.parameters()).device
        except StopIteration:
            device = c.DEVICE
        with torch.no_grad():
            self.model(torch.zeros((1, 3, *image_shape), device=device))

    def hook_function(self, module, ten_in, ten_out):
        if isinstance(module, self.layer_types):
            self.layers.append(module)
            # Drop the batch dimension; ten_in is the tuple of positional inputs.
            self.input_sizes.append(ten_in[0].shape[1:])
            self.output_sizes.append(ten_out.shape[1:])

    def print_summary(self):
        """Prints every recorded layer with its (C, H, W) input/output sizes."""
        for i, layer in enumerate(self.layers):
            print("---------------------------------------------------------")
            print(f"  layer no.{i}: {layer}")
            try:
                print(f"  input size: ({self.input_sizes[i][0]}, "
                      f"{self.input_sizes[i][1]}, {self.input_sizes[i][2]})")
                print(f" output size: ({self.output_sizes[i][0]}, "
                      f"{self.output_sizes[i][1]}, {self.output_sizes[i][2]})")
            # Sizes with fewer than 3 dims (e.g. Linear layers) raise
            # IndexError; anything else is a real error and must propagate.
            except IndexError:
                print(" This layer is not 2D.")


if __name__ == '__main__':
    # Demo: print the layer indexing and (C, H, W) sizes of AlexNet.
    # Put the model on c.DEVICE explicitly so its parameters and any input
    # created on c.DEVICE share a device (torchvision builds it on CPU).
    model = models.alexnet().to(c.DEVICE)
    inspector = SizeInspector(model, (227, 227))
    inspector.print_summary()