borderownership / src / not_used / checkcnn.py
checkcnn.py
Raw
# -*- coding: utf-8 -*-
"""checkcnn.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/13zkuKB_Hwr20k2o9A3lMquRNTZ7WGOmA

# CNN Analysis Attributes Summarizer

Tony Fu

tonyfu97@uw.edu

Bair Lab, University of Washington

June, 2021

Functions for analyzing the PyTorch implementations of convolutional neural
networks (CNN). 

The most useful function: get_info_of_all_layers(model, pretrained_model, xn, yn)

This module assumes that the kernels are square and that the x- and y-strides
and paddings are the same for each kernel. The module does not, however,
assume that the input image is square.
"""

VERSION = 1.0

import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import urllib.request
from torch.utils.data import DataLoader, SequentialSampler
from torchvision import models, transforms, datasets
from collections import OrderedDict
from PIL import Image

# Script/notebook cell: build a pretrained AlexNet when run directly.
if __name__ == '__main__':
    model = models.alexnet(pretrained=True)
    # NOTE(review): this is a bare attribute access, not a call. It has no
    # effect when run as a script; presumably `model.children()` was meant
    # to be inspected in the original notebook -- confirm before relying on it.
    model.children

"""Classes that are used as containers for layer attributes."""

class Layer:
    """Shared attribute container for the Conv2d and MaxPool2d wrappers.

    Parameters
    ----------
    model : torchvision.models.name_of_architecture
        The PyTorch implementation of a convolutional neural network. It is
        only passed on to get_RF_center().
    layer : torch.nn.Module
        The raw layer object. Its kernel_size, stride, and padding are copied
        onto this container unchanged.
    layer_index : int
        The index of the layer in the Sequential container.
    input_shape : torch.Size([num_img, num_kernels, y_in, x_in])
        The input size for this layer (not the absolute size in pixels).
        Stored as a list.
    output_shape : torch.Size([num_img, num_kernels, y_out, x_out])
        The output size for this layer (not the absolute size in pixels).
        Stored as a list.
    RF_size : int
        The absolute side-length (in pixels) of the receptive field of a unit
        of this layer. Use get_RF_sizes() to get the values.

    Other Attributes
    ----------------
    output_midpoint : tuple of 2 int
        (x, y) index of the 'middle' unit of the output grid. For example, an
        output with y_out = 6 and x_out = 5 gives (2, 3).
    middle_pixel : tuple
        The value returned by get_RF_center() for the middle unit: the (x, y)
        centre of its receptive field in image coordinates, where x is the
        horizontal axis pointing to the right and y is the vertical axis
        pointing downward.
    activation_mean, activation_std
        Placeholders, initialised to None.
    """

    @staticmethod
    def _middle_index(length):
        """Return the index of the middle element of an axis of `length` units."""
        # Subtracting 0.9 (rather than 1) keeps the value away from an exact
        # .5, so round() never has to break a tie.
        return round((length - 0.9)/2)

    def __init__(self, model, layer, layer_index, input_shape, output_shape, RF_size):
        # Identity of the wrapped layer.
        self.layer = layer
        self.layer_index = layer_index

        # Geometry copied verbatim from the wrapped layer.
        self.kernel_size = layer.kernel_size
        self.stride = layer.stride
        self.padding = layer.padding

        # Shapes are kept as plain lists.
        self.input_shape = list(input_shape)
        self.output_shape = list(output_shape)

        self.RF_size = RF_size

        # Filled in later by whoever measures activations.
        self.activation_mean = None
        self.activation_std = None

        # Axis 3 of the shape is the horizontal (x) axis, axis 2 the vertical (y) one.
        mid_x = self._middle_index(output_shape[3])
        mid_y = self._middle_index(output_shape[2])
        self._output_mid_x = mid_x
        self._output_mid_y = mid_y
        self.output_midpoint = (mid_x, mid_y)
        self.middle_pixel = get_RF_center(model, layer_index, mid_x, mid_y)


class Conv2d(Layer):
    """Attribute container describing a single Conv2d layer.

    In addition to everything Layer records, this stores the number of
    kernels (`num_of_kernels`) and reduces `kernel_size`, `stride`, and
    `padding` to their first element, because square kernels with equal x-
    and y-stride/padding are assumed. `kernel` starts out as None.
    """
    def __init__(self, model, layer, layer_index, input_shape, output_shape, RF_size):
        super().__init__(model, layer, layer_index, input_shape, output_shape, RF_size)
        self.num_of_kernels = layer.out_channels
        # Assuming square kernel: keep only the first entry of each pair.
        for attribute in ("kernel_size", "stride", "padding"):
            setattr(self, attribute, getattr(layer, attribute)[0])
        self.kernel = None


class MaxPool2d(Layer):
    """Layer attribute container for a MaxPool2d layer.

    Adds a `dilation` attribute (copied unchanged from the wrapped layer) to
    everything Layer records.

    Raises
    ------
    ValueError
        If the layer's dilation is anything other than 1. Both the scalar
        form (1) and the per-axis form, e.g. (1, 1), are accepted.
    """
    def __init__(self, model, layer, layer_index, input_shape, output_shape, RF_size):
        Layer.__init__(self, model, layer, layer_index, input_shape, output_shape, RF_size)
        self.dilation = layer.dilation

        # The dilation may be given as an int or as a per-axis tuple; a plain
        # `!= 1` comparison would wrongly reject the valid tuple form (1, 1).
        if isinstance(self.dilation, (tuple, list)):
            undilated = all(d == 1 for d in self.dilation)
        else:
            undilated = self.dilation == 1

        if not undilated:
            raise ValueError(
                "The definition of receptive field is incompatible "
                f"with a MaxPool2d dilation of {self.dilation}.")


class ReLU:
    """Layer attribute container for a ReLU layer.

    Parameters
    ----------
    layer : object with an `inplace` attribute (e.g. torch.nn.ReLU)
        The raw layer object.
    layer_index : int
        The index of the layer in the Sequential container.
    input_shape : sequence of int
        The input size for this layer. The output shape is recorded as the
        same values.

    Attributes
    ----------
    input_shape, output_shape : list
        Equal in value but stored as two independent lists, so editing one
        never changes the other.
    activation_mean, activation_std
        Placeholders, initialised to None.
    """
    def __init__(self, layer, layer_index, input_shape):
        self.layer = layer
        self.layer_index = layer_index
        self.inplace = layer.inplace
        # Two separate copies: sharing one list object between both
        # attributes would let a mutation of one leak into the other.
        self.input_shape = list(input_shape)
        self.output_shape = list(input_shape)
        self.activation_mean = None
        self.activation_std = None

def get_conv_info(layer, model):
    """Gets the information about a convolutional layer of the model CNN.

    A status message is printed in both the found and the not-found case.

    Parameters
    ----------
    layer : int
        The numbering of the convolutional layer, starting from 1.
    model : torchvision.models.name_of_cnn
        The convolutional neural network model.

    Returns
    -------
    None if the model has no such convolutional layer (or no children at
    all). Otherwise a list containing the following items in order:
        0. The original index of the conv layer in the nn.Sequential container
        1. The number of input channels
        2. The number of output channels
        3. The size of the kernel, as a tuple of length 2
        4. (Relative) Stride: the number of pixel shift over the input, as a tuple of length 2
        5. (Relative) Padding: the number of 0-pixels added to the input at the edges
    """
    children = list(model.children())
    if not children:
        # Nothing to search; report it the same way as a missing layer.
        print(f"There is no Conv{layer}.")
        return None

    # Get the nn.Sequential container. (It may be nested one level down or not.)
    first_child = children[0]
    model_layer_list = first_child if hasattr(first_child, '__iter__') else children

    conv_counter = 0  # Keep track of the convolutional layers seen so far
    for i, info in enumerate(model_layer_list):
        if not isinstance(info, nn.Conv2d):
            continue
        conv_counter += 1
        if conv_counter == layer:
            print(f"Conv{layer} was found in the {i}th layer.")
            return [i, info.in_channels, info.out_channels,
                    info.kernel_size, info.stride, info.padding]

    # If no such conv layer exists, say so and return None explicitly.
    print(f"There is no Conv{layer}.")
    return None


# Script/notebook cell: look up the second convolutional layer of a VGG16
# built with default arguments. The lookup prints its own status message;
# the returned list is discarded here.
if __name__ == '__main__':
    get_conv_info(2, models.vgg16())

def get_RF_sizes(model):
    """Gives the absolute sizes (in pixels) of the receptive fields of the units
    in the convolutional and max-pooling layers.

    Only the layers inside the model's first child container are examined.
    Square kernels with equal x- and y-strides are assumed, so only the first
    entry of a tuple-valued kernel size or stride is used.

    Parameters
    ----------
    model: torchvision.models.name_of_architecture
        The PyTorch implementation of a convolutional neural network.

    Returns
    -------
    An OrderedDict containing the layers' names as the keys, and the (absolute)
    size of the receptive field (RF) as the values. Conv layers are named
    "Conv1", "Conv2", ...; a MaxPool layer is named "MaxPool" followed by the
    number of conv layers encountered before it.
    """
    def _first(value):
        # nn.Conv2d stores pairs; nn.MaxPool2d stores whatever it was given
        # (an int or a tuple).
        return value[0] if isinstance(value, (tuple, list)) else value

    RF_sizes = OrderedDict()
    conv_counter = 0  # the number of conv layers encountered so far

    # Start from a single input pixel: RF of 1 pixel, 1 pixel between
    # neighbouring units. Initialising here (instead of when the layer index
    # is 0) keeps the function working when the container starts with a layer
    # that is neither Conv2d nor MaxPool2d.
    abs_RF_size = 1   # absolute receptive field size in pixels
    abs_stride = 1    # absolute stride in pixels

    for layer in list(model.children())[0]:
        # We only care about MaxPool and Conv layers.
        if isinstance(layer, nn.MaxPool2d):
            key = "MaxPool" + str(conv_counter)
        elif isinstance(layer, nn.Conv2d):
            conv_counter += 1
            key = "Conv" + str(conv_counter)
        else:
            continue

        kernel_size = _first(layer.kernel_size)
        stride = _first(layer.stride)

        # Each extra kernel tap widens the RF by one absolute stride.
        abs_RF_size += (kernel_size - 1)*abs_stride
        RF_sizes[key] = abs_RF_size

        # Units of this layer are `stride` times further apart in the image.
        abs_stride *= stride

    return RF_sizes


def _test_get_RF_sizes():
    """Internal self-check of get_RF_sizes() against AlexNet.

    Asserts that the receptive-field sizes computed for a default-constructed
    AlexNet match the hard-coded reference values, then prints one line per
    layer.
    """
    expected_sizes = [11, 19, 51, 67, 99, 131, 163, 195]
    sizes_by_layer = get_RF_sizes(models.alexnet())

    computed_sizes = list(sizes_by_layer.values())
    assert computed_sizes == expected_sizes, "_test_get_RF_sizes() failed"

    for layer_name in sizes_by_layer:
        print(f"{layer_name:>9} RF size: {sizes_by_layer[layer_name]}")
    
# Script/notebook cell: run the receptive-field size self-check.
if __name__ == '__main__':
    _test_get_RF_sizes()

def get_RF_center(model, layer_index, x, y):
    """Calculates the center of the unit's receptive field (RF) on the input
    image.

    Parameters
    ----------
    model : torchvision.models.name_of_cnn
        The PyTorch implementation of a convolutional neural network. Only
        the layers inside its first child container are used.
    layer_index : int
        The index of the layer of interest, starting from 0.
    x : int
        The horizontal coordinate of the unit on the output layer.
    y : int
        The vertical coordinate of the unit on the output layer.

    Returns
    -------
    The center of the unit's RF on the input image: (x, y) in pixels. The
    values are floats whenever at least one Conv2d/MaxPool2d layer is
    traversed, and end in .5 when an even kernel size makes the centre fall
    between two pixels.

    Notes
    -----
    You need this function because the center of a RF tends to drift to the
    bottom-right as you get deeper into the CNN. Use get_RF_sizes(model) to find
    the size of the RF.

    Square kernels with equal x- and y-stride/padding are assumed, so only the
    first entry of a tuple-valued kernel size, stride, or padding is used.
    """
    # Layers 0..layer_index (inclusive), reversed to project backward. The
    # slice must start AT layer_index: starting at layer_index+1 would also
    # project through the layer after the one of interest.
    layers = list(model.children())[0][layer_index::-1]

    for layer in layers:
        # Skip the layers that are not Conv2d and MaxPool2d (skip those whose
        # input size is the same as the output size.)
        if not isinstance(layer, (nn.Conv2d, nn.MaxPool2d)):
            continue

        # nn.Conv2d stores pairs; nn.MaxPool2d stores an int or a tuple.
        stride, kernel_size, padding = (
            value[0] if isinstance(value, (tuple, list)) else value
            for value in (layer.stride, layer.kernel_size, layer.padding)
        )

        # Map the coordinates of this layer to the coordinates a layer before.
        x = _map_to_earlier_layer(x, stride, kernel_size, padding)
        y = _map_to_earlier_layer(y, stride, kernel_size, padding)

    return (x, y)


def _map_to_earlier_layer(x, stride, kernel_size, padding):
    """Internal function. Map a coordinate on a layer's output to the
    coordinate of the corresponding kernel centre on that layer's input.

    Applied once per layer, from the layer of interest back to the first
    layer, this carries a coordinate all the way to the input image. The
    result is a float because of the true division.
    """
    # x*stride is the kernel's first tap on the padded input; adding half the
    # kernel extent gives its centre; subtracting the padding converts back
    # to unpadded input coordinates.
    return (x*stride) + (kernel_size-1)/2 - padding

def get_activation_std_mean(pretrained_model):
    """Measures the spread of the activations of the Conv2d, MaxPool2d, and
    ReLU layers in the model's first child container.

    The first 64 images of the CIFAR10 test split are resized to 227 x 227,
    normalized, and pushed through the layers one at a time.

    Parameters
    ----------
    pretrained_model : torchvision.models.name_of_CNN(pretrained=True)
        The PyTorch implementation of a CNN. It is moved to the 'mps' device
        when that backend is available (otherwise 'cpu') and switched to
        eval mode; both changes persist after the call.

    Returns
    -------
    An OrderedDict mapping layer names ("Conv1", "ReLU1", "MaxPool1", ...) to
    (std, mean) tuples of floats, computed over every element of the layer's
    output for the whole batch.

    Notes
    -----
    Downloads CIFAR10 into "./test/" if it is not already there.
    """
    # torch.has_mps only says whether the build supports MPS; the backend
    # check also covers machines where the device itself is unavailable.
    mps_backend = getattr(torch.backends, 'mps', None)
    mps_ready = mps_backend is not None and mps_backend.is_available()
    device = 'mps' if mps_ready else 'cpu'
    pretrained_model.to(device=device)
    pretrained_model.eval()
    xn = yn = 227
    layers = OrderedDict()
    conv_counter = 0  # the number of conv layers encountered so far

    # Download, transform, load data
    transform = transforms.Compose([transforms.Resize([yn,xn]),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std =[0.229, 0.224, 0.225])])
    dataset = datasets.CIFAR10(root="./test/", train=False, transform=transform, download=True)
    seq_sampler = SequentialSampler(range(64))
    data_loader = DataLoader(dataset, batch_size=64, sampler=seq_sampler)
    # Use the built-in next(): DataLoader iterators no longer offer .next().
    stimulus, _ = next(iter(data_loader))
    stimulus = stimulus.to(device)

    with torch.no_grad():
        for layer in list(pretrained_model.children())[0]:
            # Every layer is applied, including the ones we do not report on,
            # so that later layers receive the input they would get in the
            # real network.
            stimulus = layer(stimulus)

            # We only care about Conv2d, MaxPool2d, and ReLU layers.
            if isinstance(layer, nn.Conv2d):
                conv_counter += 1
                key = "Conv" + str(conv_counter)
            elif isinstance(layer, nn.MaxPool2d):
                key = "MaxPool" + str(conv_counter)
            elif isinstance(layer, nn.ReLU):
                key = "ReLU" + str(conv_counter)
            else:
                continue

            # The statistics are converted to Python floats right away, so a
            # later in-place layer cannot alter what was recorded.
            std, mean = torch.std_mean(stimulus, (0, 1, 2, 3))
            layers[key] = (std.item(), mean.item())

    return layers

# if __name__ == '__main__':
#     act_std_mean = get_activation_std_mean(models.alexnet(pretrained=True))
#     for layer, std_mean in act_std_mean.items():
#         std, mean = std_mean
#         print(f"{layer:9} mean: {round(mean,2):5}, std: {round(std,2):5}.")

def get_info_of_all_layers(model, pretrained_model, xn=227, yn=227):
    """Get the information of the Conv2d, MaxPool2d, and ReLU layers of a
    PyTorch implementation of a convolutional neural network (CNN).

    An all-zero image of shape (1, 3, yn, xn) is pushed through the layers of
    pretrained_model's first child container to record every layer's input
    and output shape. Both models are switched to eval mode, and stay so
    after the call.

    Parameters
    ----------
    model : torchvision.models.name_of_CNN
        The PyTorch implementation of a CNN. Used for the receptive-field
        calculations (get_RF_sizes and the layer containers).
    pretrained_model : torchvision.models.name_of_CNN(pretrained=True)
        The PyTorch implementation of a CNN whose layers are actually run and
        whose Conv2d weights are stored as `kernel`.
    xn : int, default=227
        The horizontal width (in pixels) of the input image
    yn : int, default=227
        The vertical height (in pixels) of the input image

    Returns
    -------
    An OrderedDict containing the names of the layer as the keys ("Conv1",
    "ReLU1", "MaxPool1", ...), and the layer containers (Conv2d, MaxPool2d,
    or ReLU) as the values, each of which has the following attributes
    (the included set of attributes depends on the type of the layer):
        a. layer : torch.nn.Module
            The raw layer object.
        b. layer_index : int
            The index of the layer in the Sequential container.
        c. input_shape : list [num_img, num_kernels, y_in, x_in]
            The input size for this layer. (Not the absolute size in pixels)
        d. output_shape : list [num_img, num_kernels, y_out, x_out]
            The output size for this layer. (Not the absolute size in pixels)
        e. kernel_size : int
            The size of the each individual kernel. (Not the absolute size in pixels)
        f. stride : int
            The distance between two adjacent kernels. (Not the absolute stride in pixels)
        g. padding : int
            The extra length added to the both sides of the input. Only one-sided
            value is given. (Not the absolute padding in pixels)
        h. RF_size : int
            The absolute side-length (in pixels) of the receptive field of an
            unit of a layer.
        i. num_of_kernels : int
            The number of kernels (i.e., unique units) of a Conv2d layer.
        j. kernel : torch.nn.Parameter
            The weights of a Conv2d layer.
        k. dilation : int
            A quantity describing the coverage of the MaxPool kernels.
        l. inplace : bool
            inplace=True means that ReLU is applied directly to the input
            without instantiating an new object.
        m. middle_pixel : tuple
            (x, y) coordinates of the center of the RF of the center unit.

    Notes
    -----
    The function assumes that the kernels are square, and that the strides and
    paddings are equal in x- and y-directions.
    """
    model.eval()
    # The stimulus is run through pretrained_model's layers, so that is the
    # network that has to be in eval mode for the forward pass.
    pretrained_model.eval()
    layers = OrderedDict()
    conv_counter = 0  # the number of conv layers encountered so far
    RF_sizes = get_RF_sizes(model)  # OrdDict containing RF_sizes of Conv and MaxPool
    test_stimulus_in = torch.zeros([1, 3, yn, xn])

    for i, child in enumerate(list(pretrained_model.children())[0]):
        test_stimulus_out = child.forward(test_stimulus_in).detach().clone()

        # We only care about Conv2d, MaxPool2d, and ReLU layers.
        if isinstance(child, nn.Conv2d):
            conv_counter += 1
            key = "Conv" + str(conv_counter) # used for both RF_sizes and layers OrdDicts.
            layer = Conv2d(model, child, i, test_stimulus_in.shape, test_stimulus_out.shape, RF_sizes[key])
            layer.kernel = child.weight
        elif isinstance(child, nn.MaxPool2d):
            key = "MaxPool" + str(conv_counter)
            layer = MaxPool2d(model, child, i, test_stimulus_in.shape, test_stimulus_out.shape, RF_sizes[key])
        elif isinstance(child, nn.ReLU):
            key = "ReLU" + str(conv_counter)
            layer = ReLU(child, i, test_stimulus_in.shape)
        else:
            # Not recorded, but its output still feeds the next layer.
            test_stimulus_in = test_stimulus_out
            continue

        layers[key] = layer
        test_stimulus_in = test_stimulus_out

    return layers

def _test_RF(model, pretrained_model, xn=227, yn=227):
    """Internal function. Empirically checks the receptive-field (RF) mapping.

    For each reference image and each Conv2d/MaxPool2d layer, every pixel
    outside the middle unit's predicted RF is replaced with a random value.
    If the RF is mapped correctly, the middle unit must respond identically
    to the reference image and to the altered one.

    The message: "The RF in {layer_name} is mapped incorrectly." indicates
    incorrect RF mapping. The RF is bounded by:
        left_bound  = max(0, math.floor(midpoint_x - layer.RF_size/2))
        up_bound    = max(0, math.floor(midpoint_y - layer.RF_size/2))
        right_bound = math.ceil( midpoint_x + layer.RF_size/2)
        low_bound   = math.ceil( midpoint_y + layer.RF_size/2)

    Parameters
    ----------
    model : torchvision.models.name_of_CNN
        Network whose `features` container is run on the stimuli.
    pretrained_model : torchvision.models.name_of_CNN(pretrained=True)
        Passed on to get_info_of_all_layers().
    xn, yn : int, default=227
        Width and height (in pixels) the images are resized to.

    Notes
    -----
    Downloads each image to "reference_stim.jpg" in the working directory
    and opens matplotlib figures.
    """
    # The choice of testing images is arbitrary. Use any images you like.
    url_0 = "https://upload.wikimedia.org/wikipedia/commons/1/12/University_of_Washington%2C_February_2014_-4.JPG"
    url_1 = "https://upload.wikimedia.org/wikipedia/commons/0/09/Seattle_Nov_2014_Rainy_Day_Space_Needle_%2815774343216%29.jpg"
    url_2 = "https://upload.wikimedia.org/wikipedia/commons/f/f5/Siberian_Husky_-_Mika.jpg"
    urls = [url_0, url_1, url_2]

    layers = get_info_of_all_layers(model, pretrained_model, xn, yn)

    transform = transforms.Compose([transforms.Resize([yn,xn]),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std =[0.229, 0.224, 0.225])])

    for i, url in enumerate(urls):
        # Load, transform, and plot the images
        urllib.request.urlretrieve(url, "reference_stim.jpg")
        reference_stim = Image.open("reference_stim.jpg")
        reference_stim_tensor = torch.unsqueeze(transform(reference_stim), dim=0)
        plt.figure()
        plt.imshow(reference_stim)
        plt.title(f"Reference Stimulus #{i}")

        for layer_name, layer in layers.items():
            # Only test for Conv2d and MaxPool2d layers
            if (not isinstance(layer, Conv2d)) and (not isinstance(layer, MaxPool2d)):
                continue

            # Get the activation for the reference stimulus
            features = model.features[:layer.layer_index+1] # 'features' = an nn.Sequential object
            reference_activation = features.forward(reference_stim_tensor).detach().clone()

            # Make a boolean mask for the receptive field (RF). The lower
            # bounds are clamped at 0: an RF that sticks out past the image
            # edge would otherwise give a negative slice start, which Python
            # interprets as counting from the far end.
            midpoint_x, midpoint_y = layer.middle_pixel
            left_bound  = max(0, math.floor(midpoint_x - layer.RF_size/2))
            up_bound    = max(0, math.floor(midpoint_y - layer.RF_size/2))
            right_bound = math.ceil(midpoint_x + layer.RF_size/2)
            low_bound   = math.ceil(midpoint_y + layer.RF_size/2)
            RF_mask = torch.full([1, 3, yn, xn], False)
            # The tensor layout is [image, channel, y, x]: rows (y) come
            # before columns (x).
            RF_mask[:, :, up_bound:low_bound+1, left_bound:right_bound+1] = True

            # For every reference stimulus, replace the pixels outside the RF
            # with random values. Calculate the rate of identical activation.
            exp_stimulus = torch.rand([1, 3, yn, xn])
            exp_stimulus[RF_mask] = reference_stim_tensor.detach().clone()[RF_mask]
            activation = features.forward(exp_stimulus).detach().clone()

            # Check if the middle unit responds the same to the reference and
            # the experimental stimulus.
            layer_midpt_x, layer_midpt_y = layer.output_midpoint
            result = torch.eq(reference_activation[:,:,layer_midpt_y,layer_midpt_x], activation[:,:,layer_midpt_y,layer_midpt_x]).numpy()
            if (np.sum(result) != result.shape[1]):
                print(f"The RF in {layer_name} is mapped incorrectly.")
            plt.figure(figsize=(3,3))
            plt.imshow(np.transpose(torch.squeeze(exp_stimulus),(1,2,0)))
            plt.title(f"In {layer_name}, {np.sum(result)} out of {result.shape[1]} kernels have identical activations")
            plt.show()

# Script/notebook cell: run the RF-mapping check on AlexNet with a
# 227 x 227 input.
if __name__ == '__main__':
    # Disabled alternative: the same check on VGG16 with a non-square input.
    #_test_RF(models.vgg16(pretrained=False), models.vgg16(pretrained=True), 227, 301) # arbitrary choices of xn and yn
    _test_RF(models.alexnet(pretrained=False), models.alexnet(pretrained=True), 227, 227)

# Script/notebook cell: print, for each Conv2d/MaxPool2d layer of AlexNet, its
# input/output shapes and the image-space centre of its middle unit's RF.
if __name__ == '__main__':
    xn = yn = 227
    # get_RF_center() only reads the kernel geometry of the layers, so one
    # AlexNet instance is built up front and reused, instead of constructing
    # a fresh network on every loop iteration.
    alexnet_architecture = models.alexnet()
    layers = get_info_of_all_layers(alexnet_architecture, models.alexnet(pretrained=True), xn, yn)
    # Positions of the Conv2d/MaxPool2d layers inside AlexNet's feature
    # container, paired one-to-one with the names below.
    layer_indices = [0, 2, 3, 5, 6, 8, 10, 12]
    layer_names = ["Conv1", "MaxPool1", "Conv2", "MaxPool2", "Conv3", "Conv4", "Conv5", "MaxPool5"]
    print(f"input image size: {xn} x {yn}")
    for i, string in zip(layer_indices, layer_names):
        # Middle index along the horizontal axis; reused for the vertical
        # axis because the input here is square.
        midpoint = int((layers[string].output_shape[3] + 0.9)/2)
        xy = get_RF_center(alexnet_architecture, i, midpoint, midpoint)
        print(f"{string:9}: input shape {layers[string].input_shape}, and output shape {layers[string].output_shape}")
        print(f"{string:9}: middle unit at ({midpoint:2}, {midpoint:2}) corresponding to a center at {xy}")