# File: borderownership/src/not_used/artiphysiology.py
# (Header residue from the repository web view, kept as a comment so that the
# module parses and runs.)
# -*- coding: utf-8 -*-
"""artiphysiology.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1UvfdrDKnAwJCzBczL3BU-sxAu9HVLMaT

# Artiphysiology Tool Box

Tony Fu (Dec, 2021)

Bair Lab, University of Washington

tonyfu97@uw.edu

**References**:

[1] Pospisil DA, Pasupathy A, Bair W. ['Artiphysiology' reveals V4-like shape tuning in a deep network trained for image classification](https://pubmed.ncbi.nlm.nih.gov/30570484/). Elife. 2018 Dec 20;7:e38242. doi: 10.7554/eLife.38242. PMID: 30570484; PMCID: PMC6335056.


[2] Zhou H, Friedman HS, von der Heydt R. ['Coding of border ownership in monkey visual cortex'](https://pubmed.ncbi.nlm.nih.gov/10964965/). J Neurosci. 2000 Sep 1;20(17):6594-611. doi: 10.1523/JNEUROSCI.20-17-06594.2000. PMID: 10964965; PMCID: PMC4784717.

# Preliminaries 
imports and data-loading (for testing purposes)
"""

import os
import math
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import patches 
from matplotlib.gridspec import GridSpec
import urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms
from torchvision import models
from torch.autograd import Variable
from tqdm import tqdm 
from pathlib import Path
from collections import OrderedDict, Counter
from PIL import Image

if __name__ == '__main__':
    # Notebook-only setup for Google Colab: mount Google Drive, then put the
    # project folder on the import path. This has to run before the
    # project-local modules are imported below.
    import sys
    from google.colab import drive

    drive.mount('/content/gdrive', force_remount=True)
    sys.path.insert(0, '/content/gdrive/My Drive/Colab Notebooks/Border Ownership Research')

# Import the customized modules
import boshape as bo     # generates Border-ownership shapes
import checkcnn as ccnn  # returns summaries of CNN attributes

if __name__ == '__main__':
    # Folders holding the precomputed receptive field (RF) statistics of the
    # conv layers of each network.
    dir_alex = "/content/gdrive/My Drive/Colab Notebooks/Border Ownership Research/AlexNet Stats/"
    dir_vgg16 = "/content/gdrive/My Drive/Colab Notebooks/Border Ownership Research/VGG16 Stats/"

    # Load the receptive field (RF) statistics of conv layers 2 to 5 of AlexNet.
    unit_stats_alex = {}
    ori_stats_alex = {}
    for i in range(2, 6):
        rfs = np.load(dir_alex + f'n01_stat_conv{i}_mrf_4a.npy')
        # Also expose the raw array under its per-layer module-level name
        # (rfs_alex_conv2, ...) without resorting to exec().
        globals()[f'rfs_alex_conv{i}'] = rfs
        # Insert a leading column containing the indices of the units.
        unit_stats_alex[f'Conv{i}'] = np.insert(rfs, 0, range(0, len(rfs)), axis=1)
        # Load another file containing the unit's preferred orientation.
        # NOTE(review): AlexNet uses the '_siz_1' file while VGG16 uses
        # '_sine_2' — confirm both file names are intended.
        ori_stats_alex[f'Conv{i}'] = np.load(dir_alex + f'n01_stat_conv{i}_siz_1.npy')

    # Load the receptive field (RF) statistics of conv layers 7 to 12 of VGG16.
    unit_stats_vgg16 = {}
    ori_stats_vgg16 = {}
    for i in range(7, 13):
        rfs = np.load(dir_vgg16 + f'n03_stat_conv{i}_mrf_4a.npy')
        # Per-layer module-level name (rfs_vgg16_conv7, ...), as for AlexNet.
        globals()[f'rfs_vgg16_conv{i}'] = rfs
        # Insert a leading column containing the indices of the units.
        unit_stats_vgg16[f'Conv{i}'] = np.insert(rfs, 0, range(0, len(rfs)), axis=1)
        # Load another file containing the unit's preferred orientation.
        ori_stats_vgg16[f'Conv{i}'] = np.load(dir_vgg16 + f'n03_stat_conv{i}_sine_2.npy')

"""Each row in <rfs_network_convi> contains 6 values in this order:
[0] unit_idx - the unit's number starting from 0
[1] R_max    - the maximum response across the RF mapping stimuli.
[2] COM_0    - center of mass (pix) along vertical image axis (center=0)
[3] COM_1    - center of mass (pix) along horizontal image axis (center=0)
[4] Rad_90   - the radius (pix) that includes 90% of the MRF map for a circle 
              around the COM point. The value "-1" indicates that the RF map 
              failed, for example, if responses to all stimuli were less than
              or equal to zero.
[5] f_nat    - ratio of R_max to the average of the top 10 responses to natural images
"""

"""Each row in <rfs_network_convi_ori> contains 10 values in this order:
[0] unit_idx - the unit's number starting from 0
[1] R_max    - the maximum raw response across all sinusoidal grating stimuli.
[2] T_max    - the maximum tuning-curve response (either DC or F1) over ori, SF, 
               and size after collapsing over phase.
[3] ori_max  - orientation at R_max (deg)
[4] sf_max   - spatial frequency at R_max (cyc/pix)
[5] size_max - diameter at R_max (pix)
[6] phase_max- spatial phase at R_max (pdeg)
[7] mod_idx  - modulation index: F1/(DC - BR). Values >= 1 indicate simple cells,
               <1 indicates complex cells.
[8] sup_idx  - suppression index: (T_max - T_large)/(T_max - BR)
[9] BR       - baseline response to blank (all-zero) input
"""

"""# Pre-processing"""

def get_unit_info(unit_stats, unit_num, print_info=False):
    """Return (and optionally print) the receptive field (RF) stats of a unit.

    Parameters
    ----------
    unit_stats : n x 6 numpy array
        Each row contains six RF stats (i.e., unit_idx, R_max, CM_y, CM_x, RF_r,
        and f_nat) about one unit in the convolutional layer. The first dimension
        n is the number of unique units in the convolutional layer.
    unit_num : int
        The unit's index in the layer. It is matched against the first column
        (unit_idx) of unit_stats, not against the row position.
    print_info : bool, default=False
        If true, the six stats are also printed in a human-readable form.

    Returns
    -------
    list
        The first row of unit_stats whose unit_idx equals unit_num, converted
        to a plain Python list of 6 values.

    Raises
    ------
    IndexError
        If no row of unit_stats has unit_num as its unit_idx.
    """
    # Stop at the first matching row instead of scanning the whole table.
    for row in unit_stats:
        if row[0] == unit_num:
            unit_info = row.tolist()
            break
    else:
        raise IndexError(f"Unit #{unit_num} was not found in unit_stats.")

    if print_info:
        txt = f"""For unit #{unit_num} in this layer:
        0. unit_idx = {unit_info[0]}
        1. R_max = {unit_info[1]}
        2. CM_y  = {unit_info[2]}
        3. CM_x  = {unit_info[3]}
        4. RF_r  = {unit_info[4]}
        5. f_nat = {unit_info[5]}"""
        print(txt)

    return unit_info


# Test
if __name__ == '__main__':
    # Smoke test: print the RF stats of unit #2 of AlexNet's Conv2 layer.
    get_unit_info(unit_stats=unit_stats_alex['Conv2'], unit_num=2, print_info=True)

def clean_unit_stats(unit_stats, middle_pixel, RF_size, f_nat_threshold=0.2,
                     distance_threshold=15, xn=227, yn=227):
    """Suggests which units (i.e., convolutional kernels) to exclude.

    Four independent criteria are evaluated for every row of unit_stats, and a
    unit is suggested for exclusion as soon as it meets one of them. The units
    flagged by each criterion, and the units left, are printed. Nothing is
    removed from unit_stats itself.

    Parameters
    ----------
    unit_stats : n x 6 numpy array
        Each row contains six RF stats (i.e., unit_idx, R_max, CM_y, CM_x, RF_r,
        and f_nat) about one unit in the convolutional layer.
    middle_pixel : tuple of 2 int
        The image coordinate (x, y) of the center of the theoretical maximal
        RF of the layer.
    RF_size : int
        The theoretical maximal size of the unit's RF of that layer.
    f_nat_threshold : float, default=0.2
        Units with an f_nat below this value are flagged.
    distance_threshold : int, default=15
        A unit is flagged when RF_size/2 minus the larger of the horizontal and
        vertical distances between its RF center and middle_pixel is smaller
        than this number. The larger the threshold, the more units are flagged.
    xn : int, default=227
        The horizontal width (in pixels) of the input image.
    yn : int, default=227
        The vertical height (in pixels) of the input image.

    Returns
    -------
    A sorted list (ascending, no duplicates) of the unit indices (first column
    of unit_stats, as ints) that are suggested to be excluded from analyses.
    """
    def _margin_to_rf_edge(row):
        # CM_x / CM_y are relative to the image center, so shift them by half
        # the image size before comparing with middle_pixel.
        offset_x = abs((row[3] + xn/2) - middle_pixel[0])
        offset_y = abs((row[2] + yn/2) - middle_pixel[1])
        return RF_size/2 - max(offset_x, offset_y)

    # f_nat below the threshold.
    low_f_nat = [int(row[0]) for row in unit_stats if row[5] < f_nat_threshold]
    # RF mapping failed (radius stored as -1).
    failed_rf = [int(row[0]) for row in unit_stats if math.isclose(row[4], -1)]
    # Empirical RF radius above 75% of the theoretical RF size.
    too_big = [int(row[0]) for row in unit_stats if row[4] > RF_size*0.75]
    # RF center too close to, or outside, the edge of the theoretical RF.
    too_far = [int(row[0]) for row in unit_stats
               if _margin_to_rf_edge(row) < distance_threshold]

    print(f"Units deleted due to low f_nat: {low_f_nat}.")
    print(f"Units deleted due to failed RFs: {failed_rf}.")
    print(f"Units deleted due to impossibly large RFs: {too_big}.")
    print(f"Units deleted due to impossible RF center: {too_far}.")

    flagged = sorted(set(low_f_nat) | set(failed_rf) | set(too_big) | set(too_far))
    remaining = sorted(set(unit_stats[:,0].tolist()) - set(flagged))
    units_left_ints = [int(unit) for unit in remaining]
    print(f"Units left: {units_left_ints}")
    return flagged


# Test
if __name__ == '__main__':
    layer_name = 'Conv2'
    xn = yn = 227
    model = models.alexnet()
    pretrained_model = models.alexnet(pretrained=True)
    unit_stats = unit_stats_alex[layer_name]
    # Column 3 of the orientation stats is the preferred orientation in
    # degrees; convert it to radians.
    ori_stats = ori_stats_alex[layer_name][:,3] * math.pi / 180
    alexnet_layer_dict = ccnn.get_info_of_all_layers(model, pretrained_model, xn, yn)
    layer_idx = alexnet_layer_dict[layer_name].layer_index
    middle_pixel = alexnet_layer_dict[layer_name].middle_pixel
    RF_size = alexnet_layer_dict[layer_name].RF_size
    rows_to_delete = clean_unit_stats(unit_stats, middle_pixel, RF_size)

    # rows_to_delete holds unit indices (column 0 of unit_stats), not row
    # positions, so rows are selected by matching that column. Unlike a
    # positional np.delete, this stays correct when rows are already missing
    # and is a no-op when applied a second time.
    unit_ids = (unit_stats[:,0] + 0.1).astype(int)  # add 0.1 to prevent rounding error
    keep = ~np.isin(unit_ids, rows_to_delete)
    unit_stats = unit_stats[keep]
    ori_stats = ori_stats[keep]

def get_max_shape_size(RF_size, middle_pixel, RF_x, RF_y, ori=math.pi/4):
    """Find the maximum allowable stimulus size for the unit with the given
    RF field coordinate and orientation. We need this value because the stimulus
    shape must be bounded by the theoretical maximal RF.

    If the stimuli exceed the theoretical maximal RF, it is mathematically
    impossible for the unit to know what is the figure and what is the
    background, and any discriminations as a result of the overly large
    stimuli are then due to selectivities other than border-ownership. So, do
    use this function to determine the maximal size of the stimuli.

    Parameters
    ----------
    RF_size : int
        The theoretical maximal size of the unit's RF of that layer. Can be
        calculated using the function 'checkcnn.get_RF_sizes(model)'.
    middle_pixel : tuple of 2 int
        The image coordinate (x, y) of the Conv2d (or MaxPool2d) layer if the
        unit is back-projected onto the first input layer, that is, the center
        of the theoretical maximal RF.
    RF_x : int
        The horizontal center of mass (COM) of the RF on the image coordinate
        (i.e., the origin is at the top-left corner). Must not be negative.
        The RF here is not theoretical but rather determined empirically. Use
        the RF_x that is the farthest from middle_pixel[0].
    RF_y : int
        The vertical counterpart of RF_x. Must not be negative. Use the RF_y
        that is the farthest from middle_pixel[1].
    ori : float, default=math.pi/4
        The preferred orientation of the unit in radians, from 0 to pi
        (inclusive). To be safe, ori is by default pi/4, giving the smallest
        maximum allowable stimulus size.

    Returns
    -------
    int
        The maximum allowable stimulus size for the unit. The value is not
        checked for being very small or negative, which happens when the RF
        center lies far from middle_pixel; remove such outlier units first.

    Raises
    ------
    ValueError
        If RF_x or RF_y is negative, or if ori lies outside [0, pi].
    """
    if RF_x < 0 or RF_y < 0:
        raise ValueError("RF_x and RF_y must be positive.")
    if not (0 <= ori <= math.pi):
        raise ValueError("ori must be from 0 to pi.")

    # Side of the largest axis-aligned square that is centered on the RF and
    # still fits inside the theoretical RF.
    max_shape_size_x = RF_size - 2*abs(middle_pixel[0] - RF_x)
    max_shape_size_y = RF_size - 2*abs(middle_pixel[1] - RF_y)
    max_shape_size = min(max_shape_size_x, max_shape_size_y)

    # Axis-aligned orientations are handled exactly, so that floating-point
    # residue in sin/cos (e.g., sin(pi) != 0) cannot shave a pixel off.
    if ori in (0, math.pi/2, math.pi):
        return math.floor(max_shape_size)

    # Imagine a rotated square trapped inside a larger square with the side
    # length <max_shape_size>. The absolute values make the same formula valid
    # for the second quadrant (pi/2 < ori < pi), where cos(ori) is negative.
    return math.floor(max_shape_size / (abs(math.sin(ori)) + abs(math.cos(ori))))
    

# Test
if __name__ == '__main__':
    # The RF center is 5 px off the middle pixel on both axes, which leaves a
    # 50 - 2*5 = 40 px box; at the default 45 degrees: floor(40/sqrt(2)) = 28.
    ans = get_max_shape_size(RF_size=50, middle_pixel=(40, 40), RF_x=35, RF_y=35)
    assert ans == 28

"""# Get Responses"""

_DEFAULT_ALEXNET = None  # pretrained AlexNet, created on first use only


def _get_default_alexnet():
    """Return the shared pretrained AlexNet, loading it the first time it is needed."""
    global _DEFAULT_ALEXNET
    if _DEFAULT_ALEXNET is None:
        _DEFAULT_ALEXNET = models.alexnet(pretrained=True)
    return _DEFAULT_ALEXNET


def get_activation_maps(stimulus_dict, layer_idx, model=None):
    """Presents stimuli to a convolutional layer of a CNN and outputs the
    activation maps of the units (a.k.a. kernels/filters) in the stimulus_dict.

    Parameters
    ----------
    stimulus_dict : dictionary of {int: numpy array}
        The dictionary generated by one of the functions of the boshape.py
        module. The keys are the unit indices and the values are the stimuli
        made for that unit. The stimuli have NOT been normalized yet; they are
        copied, so the caller's arrays are left untouched.
        (Assumed layout: (n_stimuli, 3, height, width) — TODO confirm against
        boshape.)
    layer_idx : int
        The index of the layer, starting from 0. For example, 'Conv2' of AlexNet
        has a layer_idx of 3.
    model : torchvision.models.name_of_cnn, optional
        The pretrained convolutional neural network that corresponds to the
        stimulus_dict. Only model.features[:layer_idx+1] is evaluated. When
        omitted, the torchvision pretrained AlexNet is used; it is created on
        the first such call rather than when this module is imported. The
        model is switched to eval mode.

    Returns
    -------
    A dictionary with the keys being the indices of the units, and the values
    being the corresponding feature maps (a.k.a. activation maps) as numpy
    arrays: for each stimulus, the channel number unit_idx of the layer output.
    """
    if model is None:
        model = _get_default_alexnet()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Loop-invariant setup: inference mode and the truncated network.
    model.eval()
    features = model.features[:layer_idx+1]  # 'features' = an nn.Sequential object

    activation_maps = {}
    # No gradients are needed for a pure forward pass.
    with torch.no_grad():
        for unit_idx, stimuli in stimulus_dict.items():
            # Get and normalize the stimuli made specifically for the unit.
            stimulus_set = torch.from_numpy(stimuli.copy()).type('torch.FloatTensor')
            stimulus_set = normalize(stimulus_set)

            # Present the stimuli to the layers of interest, then keep only
            # this unit's feature (or activation) map. The maps from the other
            # units are not stored.
            activation = features(stimulus_set).numpy()
            activation_maps[unit_idx] = activation[:, unit_idx, ...]

    return activation_maps

def get_center_responses(activation_maps):
    """Picks, for every unit, the response at the center of its activation map.

    Parameters
    ----------
    activation_maps : dict
        A dictionary generated by the function 'get_activation_maps'. The keys
        are the indices of the units, and the values are the corresponding
        activation maps, indexed as [stimulus, row, column].

    Returns
    -------
    A dictionary containing the unit_idx as the keys, and, as the values, the
    responses of the center position of the map to every stimulus.

    NOTE(review): for an odd map size n the index used is (n + 1)/2, which is
    one past the geometric middle under 0-based indexing — confirm this matches
    the convention used for middle_pixel.
    """
    def _center_index(length):
        # 0.9 is added instead of 1 to prevent rounding error.
        return round((length + 0.9)/2)

    return {
        unit: maps[:, _center_index(maps.shape[1]), _center_index(maps.shape[2])]
        for unit, maps in activation_maps.items()
    }
        
# center_responses = get_center_responses(standard_activation)

# for unit, result in center_responses.items():
#     print(unit, result)

def get_permutated_center_responses(layer_idx, pretrained_model, unit_stats, ori_stats, middle_pixel, RF_size,
                                    shape_rgb, background_rgb, test_mix, xn=227, yn=227):
    """
    Creates slightly permutated shapes according to the units' preferences and
    returns the activations of the center units. Because pretrained networks
    do not have randomness as the brain does, we need to find another way to
    artificially introduce randomness to the response-generation process.

    For every test, the stimulus parameters are jittered as follows: the shape
    lengths are scaled by a factor in [0.8, 1.2], the orientations get up to
    +/-0.1 rad of noise (clipped to [0, pi]), the RF centers get up to +/-1
    pixel of noise, and each RGB channel gets up to +/-1/6 of noise (clipped
    to [0, 1]).

    The output data of this function will then be used to calculate the
    selectivity matrics of each unit using the function
    get_permutated_selectivity_matrics(). This is analogous to the
    "reliability" measure used in Zhou et al.'s paper (2000).

    Caution: This function takes time!

    Parameters
    ----------
    layer_idx : int
        The index of the layer, starting from 0. For example, 'Conv2' of AlexNet
        has a layer_idx of 3.
    pretrained_model : torchvision.models.name_of_cnn
        The pretrained convolutional neural network to present the stimuli to.
    unit_stats : n x 6 numpy array of float
        Each row contains six RF stats (i.e., unit_idx, R_max, CM_y, CM_x, RF_r,
        and f_nat) about one unit in the convolutional layer.
    ori_stats : n x 1 numpy array of float
        The counterclockwise rotation of the rectangle in radians, from 0 to pi.
    middle_pixel : tuple of 2 int
        The image coordinate (x, y) of the center of the theoretical maximal RF.
    RF_size : int
        The theoretical maximal size of the unit's RF of that layer.
    shape_rgb : tuple of 3 floats
        The color intensity of the rectangle, from 0.0 to 1.0 (no default).
    background_rgb : tuple of 3 floats
        The background color intensity, from 0.0 to 1.0 (no default).
    test_mix : tuple of 3 ints
        The number of (standard, overlap, c_figures) sets to generate.
    xn : int, default=227
        The horizontal width of the stimulus image.
    yn : int, default=227
        The vertical length of the stimulus image.

    Returns
    -------
    permutated_responses : dict
        key : unit_idx
        value : np.array of size (sum(test_mix), 4)
            One row per test, in the order the tests were run. Each row holds
            the 4 center responses of the stimulus set, arranged in order
            [A,C,B,D]. For c-figures the responses are reordered with indices
            [2,3,0,1] to account for the side difference.
        ** This output data is intended to be used as the argument for the
           function: get_permutated_selectivity_matrics()**
    """
    # Determine lx and ly for this layer from the RF centers that lie farthest
    # from the middle pixel (to make sure the shape is within the RF).
    RF_xs = unit_stats[:,3] + math.floor((xn + 1.1)/2)
    RF_ys = unit_stats[:,2] + math.floor((yn + 1.1)/2)
    farthest_RF_x = RF_xs[np.argmax(abs(RF_xs - middle_pixel[0]))]
    farthest_RF_y = RF_ys[np.argmax(abs(RF_ys - middle_pixel[1]))]
    # NOTE(review): the size returned by get_max_shape_size() is doubled here —
    # confirm this is intended.
    max_shape_size = get_max_shape_size(RF_size, middle_pixel, farthest_RF_x, farthest_RF_y) * 2
    lx = ly = max_shape_size
    print(f"farthest_RF_x: {farthest_RF_x}, farthest_RF_y: {farthest_RF_y}, max_shape_size: {max_shape_size}")

    num_units = len(ori_stats)
    total_tests = np.sum(test_mix)
    permutated_responses = {}
    test_counter = 0  # row of the response matrices filled by the current test

    for i, test in enumerate(['standard', 'overlap', 'c_figure']):
        for _ in range(test_mix[i]):
            # Permutate the stimulus parameters. The order of the random draws
            # is kept fixed so that seeded runs remain reproducible.
            lx_p = random.uniform(0.8, 1.2) * lx
            ly_p = random.uniform(0.8, 1.2) * ly
            ori_noises = (np.random.rand(num_units) - 0.5)/5
            ori_stats_p = np.clip(ori_noises + ori_stats, 0, math.pi)
            unit_stats_p = unit_stats.copy()
            CM_y_noises = (np.random.rand(num_units) - 0.5)*2
            unit_stats_p[:,2] = unit_stats[:,2] + CM_y_noises
            CM_x_noises = (np.random.rand(num_units) - 0.5)*2
            unit_stats_p[:,3] = unit_stats[:,3] + CM_x_noises
            shape_rgb_noises = (np.random.rand(3) - 0.5)/3
            shape_rgb_p = np.clip(shape_rgb + shape_rgb_noises, 0, 1)
            background_rgb_noises = (np.random.rand(3) - 0.5)/3
            background_rgb_p = np.clip(background_rgb + background_rgb_noises, 0, 1)

            if test == 'standard':
                # Stimulus      # Side      # Contrast (from L to R) when shape_rgb > background
                #    A             L              light to dark
                #    C             L              dark to light
                #    B             R              light to dark
                #    D             R              dark to light
                stimulus_set = bo.make_standard_sets(unit_stats_p, ori_stats_p, lx=lx_p, ly=ly_p,
                                        shape_rgb=shape_rgb_p, background_rgb=background_rgb_p)
            elif test == 'overlap':
                # Same stimulus layout as the standard sets.
                stimulus_set = bo.make_overlap_sets(unit_stats_p, ori_stats_p,
                                    lx_front=lx_p, ly_front=ly_p, lx_back=lx_p, ly_back=ly_p,
                                    overlap_distance=ly_p/2, overlap_misalignment=lx_p/2,
                                    front_shape_rgb=shape_rgb_p, back_shape_rgb=background_rgb_p)
            elif test == 'c_figure':
                # Stimulus      # Side      # Contrast (from L to R) when shape_rgb > background
                #    A             R              light to dark
                #    C             R              dark to light
                #    B             L              light to dark
                #    D             L              dark to light
                # The side difference of c-figures is accounted for after
                # extracting the center responses.
                stimulus_set = bo.make_c_figure_sets(unit_stats_p, ori_stats_p, lx=lx_p*1.3, ly=ly_p*1.7, l_inside=lx_p,
                                                     shape_rgb=shape_rgb_p, background_rgb=background_rgb_p)

            # Pass in stimuli and generate activation maps
            activations = get_activation_maps(stimulus_set, layer_idx, model=pretrained_model)

            # Get responses of the center units of the kernels from the activation maps
            center_responses = get_center_responses(activations)

            # Record this test's responses in row <test_counter> of every unit.
            for unit, responses in center_responses.items():
                # Create the unit's response matrix the first time it shows up.
                if unit not in permutated_responses:
                    permutated_responses[unit] = np.zeros((total_tests, 4))
                if test == 'c_figure':  # account for the side difference in c-figures
                    responses = responses[[2,3,0,1]]
                permutated_responses[unit][test_counter, :] = responses
            test_counter += 1
    return permutated_responses


# Tests
if __name__ == '__main__':
    # Demo: jittered responses of three AlexNet Conv2 units to
    # 2 standard, 1 overlap and 2 c-figure stimulus sets.
    lx = ly = 25
    xn = yn = 227
    layer_name = 'Conv2'
    unit_idx = [58, 132, 133]
    test_mix = (2,1,2)

    model = models.alexnet(pretrained=False)
    pretrained_model = models.alexnet(pretrained=True)

    # Keep only the rows of the selected units; orientation: degrees -> radians.
    unit_stats = unit_stats_alex[layer_name][unit_idx,:]
    ori_stats = ori_stats_alex[layer_name][unit_idx,3] * math.pi / 180

    layer_dict = ccnn.get_info_of_all_layers(model, pretrained_model, xn, yn)
    layer_info = layer_dict[layer_name]
    layer_idx = layer_info.layer_index
    middle_pixel = layer_info.middle_pixel
    RF_size = layer_info.RF_size

    permutated_responses = get_permutated_center_responses(
        layer_idx, pretrained_model, unit_stats, ori_stats, middle_pixel, RF_size,
        shape_rgb=(0.75,0.75,0.75), background_rgb=(0.35,0.35,0.35), test_mix=test_mix)
    print(permutated_responses)

"""# Analyze Responses"""

def _min_over_max(x, y):
    """Return smaller/larger of two non-negative sums, or 1 when both are zero."""
    smaller, larger = (x, y) if x <= y else (y, x)
    if larger == 0:
        # Both sums are zero: no preference either way.
        return 1
    return smaller / larger


def get_selectivity_matrics(center_responses, threshold_ratio=0.5, too_small=1):
    """
    Calculates the side of ownership index (SOI) and contrast polarity of the
    neuron's responses to a stimulus set. It also calculates two additional
    matrics to determine which side and contrast polarity the unit prefers,
    and classifies the unit into a selectivity type.

    The responses are first offset so that the smallest of the four is zero.

    Parameters
    ----------
    center_responses : dict
        key : int
            The unit's index
        value: np.array or torch.Tensor
            Contains the unit's activations for stimuli A,C,B,D (in that order).
    threshold_ratio : float
        If the nonpreferred/preferred is SMALLER than this ratio, then the
        preference is considered TRUE.
    too_small : float
        Output Type 0 SOI and CP (both equal to 1) when all the offset
        responses are smaller than this number.

    Returns
    -------
    matrics : dict
        key : int
            The unit's index.
        value : dict (keys are the name of the values)
            side_of_ownership : float
                Side of ownership index (SOI): the smaller of (A+C) and (B+D)
                divided by the larger, or 1 when both are zero. The LOWER the
                SOI, the stronger the border-ownership selectivity.
            contrast_polarity : float
                Contrast polarity index: the smaller of (A+B) and (C+D) divided
                by the larger, or 1 when both are zero. The LOWER the index,
                the stronger the contrast-polarity selectivity.
            side_matric : float
                Which side does the unit own? Side > 0 (left), Side < 0 (right).
            contrast_matric : float
                Which contrast (from left to right of the RF) does the unit
                prefer? Contrast > 0 (light to dark); Contrast < 0 (dark to light).
            selectivity : int
                0 = selective for neither, 1 = border ownership,
                2 = contrast polarity, 3 = both.
    """
    matrics = {}
    for unit, response in center_responses.items():
        # Indices [0,1,2,3] correspond to [A,C,B,D]
        response_offset = -min(response)
        A = response[0] + response_offset
        C = response[1] + response_offset
        B = response[2] + response_offset
        D = response[3] + response_offset

        # Calculating matrics. After the offset all four values are >= 0, so
        # the smaller of the two reciprocal ratios is min/max; the helper also
        # covers a zero denominator instead of dividing by it.
        side_of_ownership = _min_over_max(B + D, A + C)  # SOI
        contrast_polarity = _min_over_max(C + D, A + B)  # CP
        side_matric = (A - B) + (C - D) # side > 0 (left); side < 0 (right)
        contrast_matric = (A - C) + (B - D) # contrast > 0 (light -> dark); contrast < 0 (dark -> light)

        # If all numbers are too small, then output the SOI and CP that would
        # indicate Type 0 selectivity.
        if all(value < too_small for value in (A, B, C, D)):
            side_of_ownership = 1
            contrast_polarity = 1

        # Determine the selectivity using artificial logic
        is_bo = (threshold_ratio*abs(A-B) > abs(A-C) and threshold_ratio*abs(C-D) > abs(B-D) and (A-B)*(C-D)>0)
        is_contrast = (threshold_ratio*abs(A-C) > abs(A-B) and threshold_ratio*abs(B-D) > abs(C-D) and (A-C)*(B-D)>0)
        is_both = (threshold_ratio*abs(A-B) > abs(A-C) or threshold_ratio*abs(C-D) > abs(B-D)) and\
                  (threshold_ratio*abs(A-C) > abs(A-B) or threshold_ratio*abs(B-D) > abs(C-D))

        if is_bo:
            selectivity = 1   # Type 1: selective for border ownership
        elif is_contrast:
            selectivity = 2   # Type 2: selective for contrast polarity
        elif is_both:
            selectivity = 3   # Type 3: selective for both
        else:
            selectivity = 0   # Type 0: selective for neither

        matrics[unit] = {"side_of_ownership": side_of_ownership,
                        "contrast_polarity": contrast_polarity,
                        "side_matric": side_matric,
                        "contrast_matric": contrast_matric,
                        "selectivity": selectivity}

    return matrics

def get_permutated_selectivity_matrics(permutated_responses, threshold_ratio=0.5, too_small=1):
    """
    Calculates the selectivity matrics of get_selectivity_matrics() for every
    test of a multi-test response set.
    **Unlike get_selectivity_matrics(), this function expects responses from
    not just one but MULTIPLE TESTS.**

    Parameters
    ----------
    permutated_responses : dict
        key : unit_idx
        value : np.array of size (num_test, 4)
            One row per test, each holding the 4 center responses of a stimulus
            set arranged in order [A,C,B,D] (as produced by
            get_permutated_center_responses()). The number of tests is taken
            from the first entry.
    threshold_ratio : float
        If the nonpreferred/preferred is SMALLER than this ratio, then the
        preference is considered TRUE.
    too_small : float
        Output Type 0 SOI and CP when all the offset responses are smaller
        than this number.

    Returns
    -------
    permutated_matrics : dict
        Empty when permutated_responses is empty.
        key : int
            The unit's index.
        value : dict (keys are the name of the values)
            side_of_ownership : np.array of size (num_test,)
                Side of ownership index (SOI). The LOWER the SOI, the stronger
                the border-ownership selectivity.
            contrast_polarity : np.array of size (num_test,)
                Contrast polarity index. The LOWER the index, the stronger the
                contrast-polarity selectivity.
            side_matric : np.array of size (num_test,)
                Which side does the unit own? Side > 0 (left), Side < 0 (right).
            contrast_matric : np.array of size (num_test,)
                Which contrast (from left to right of the RF) does the unit
                prefer? Contrast > 0 (light to dark); Contrast < 0 (dark to light).
            selectivity : np.array of int, size (num_test,)
                The selectivity type assigned by get_selectivity_matrics().
    """
    float_matrics = ("side_of_ownership", "contrast_polarity", "side_matric", "contrast_matric")

    if not permutated_responses:
        return {}

    num_test = next(iter(permutated_responses.values())).shape[0]
    permutated_matrics = {}
    for i in range(num_test):
        # Assemble the single-test input for get_selectivity_matrics()
        temp_dict = {unit: responses[i,:] for unit, responses in permutated_responses.items()}

        # Get matrics of all units for this test. too_small is forwarded so
        # that the caller's value is actually honored.
        matrics_of_the_test = get_selectivity_matrics(temp_dict, threshold_ratio=threshold_ratio,
                                                      too_small=too_small)

        # Record the matrics in the permutated_matrics (multi-test data)
        for unit, matric in matrics_of_the_test.items():
            # Allocate the per-test arrays the first time the unit is seen.
            if unit not in permutated_matrics:
                permutated_matrics[unit] = {name: np.zeros((num_test,)) for name in float_matrics}
                permutated_matrics[unit]["selectivity"] = np.zeros((num_test,), dtype=int)
            # Fill in the values
            for name in float_matrics:
                permutated_matrics[unit][name][i] = matric[name]
            permutated_matrics[unit]["selectivity"][i] = matric["selectivity"]

    return permutated_matrics
        
# Test (please run the previous cell to get 'permutate_responses')
if __name__ == '__main__':
    # Requires 'permutated_responses' from the previous test cell.
    permutated_matrics = get_permutated_selectivity_matrics(
        permutated_responses=permutated_responses, threshold_ratio=0.5, too_small=1)
    print(permutated_matrics[58])

def get_dependability(permutated_matrics):
    """
    Calculates 6 dependability matrics for each unit of a layer.

    Parameters
    ----------
    permutated_matrics : dict
        key : int
            The unit's index.
        value : dict (keys are the name of the values)
            side_of_ownership : np.array of size (num_test,)
                Side of ownership index (SOI). The LOWER the SOI, the stronger
                the border-ownership selectivity.
            contrast_polarity : np.array of size (num_test,)
                Contrast polarity index (CPI). The LOWER the index, the
                stronger the contrast-polarity selectivity.
            side_matric : np.array of size (num_test,)
                Which side does the unit own? Side > 0 (left), Side < 0 (right).
            contrast_matric : np.array of size (num_test,)
                Which contrast (from left to right of the RF) does the unit
                prefer? Contrast > 0 (light to dark); Contrast < 0 (dark to light).
            selectivity : np.array of int, size (num_test,)
                The selectivity type code of each test.
    *To get permutated_matrics, run get_permutated_selectivity_matrics()*

    Returns
    -------
    dependabilities : dict
        key : int
            The unit's index.
        value : dict (keys are the name of the values)
            SOI_std : float
                The standard deviation of the side-of-ownership index (SOI).
                The LOWER the SOI_std, the MORE DEPENDABLE a unit is in terms of
                border-ownership selectivity.
            CPI_std : float
                The standard deviation of the contrast-polarity index (CPI).
                The LOWER the CPI_std, the MORE DEPENDABLE a unit is in terms of
                contrast-polarity selectivity.
            side_dependability : float
                The fraction of tests that share the majority sign of
                side_matric. 1.0 = perfectly consistent. Tests whose value is
                exactly 0 count towards neither sign.
            contrast_dependability : float
                The fraction of tests that share the majority sign of
                contrast_matric. 1.0 = perfectly consistent. Tests whose value
                is exactly 0 count towards neither sign.
            major_selectivity : int
                The most frequent selectivity type code.
                NOTE(review): the original documentation lists
                "Type1 = BO; Type2 = Contrast; Type3 = Both; Type4 = Neither",
                while the plotting helpers in this file index codes 0-3 --
                TODO confirm the mapping against get_selectivity_matrics().
            selectivity_variability : float
                How variable the selectivity type (a categorical variable) is:
                1 - sqrt(p_0^2 + ... + p_n^2), where p_i is the fraction of
                tests that produced type i. 0 means every test produced the
                same type; the value grows as the types become more mixed.

    Example
    -------
    dependable_matrics = {"side_of_ownership": np.array([0.3, 0.25, 0.31, 0.34, 0.33]),
                        "contrast_polarity": np.array([0.8, 0.77, 0.81, 0.82, 0.79]),
                        "side_matric": np.array([10, 20, 15, 22, 17]),
                        "contrast_matric": np.array([-10, -20, -55, -66, -20]),
                        "selectivity": np.array([1,1,1,0,1], dtype=int)}
    not_dependable_matrics = {"side_of_ownership": np.array([0.3, 0.7, 0.31, 0.9, 0.4]),
                        "contrast_polarity": np.array([0.2, 0.3, 0.5, 0.82, 0.79]),
                        "side_matric": np.array([10, 20, -15, -22, 17]),
                        "contrast_matric": np.array([-10, -20, 55, -66, -20]),
                        "selectivity": np.array([2,2,1,0,1], dtype=int)}
    pm = {1:dependable_matrics, 2:not_dependable_matrics}
    dependabilities = get_dependability(pm)
    print(dependabilities[1])
    print(dependabilities[2])
    """
    dependabilities = {}

    for unit, matrics in permutated_matrics.items():
        side = matrics["side_matric"]
        contrast = matrics["contrast_matric"]
        selectivity = matrics["selectivity"]

        SOI_std = np.std(matrics["side_of_ownership"])  # lower the better
        CPI_std = np.std(matrics["contrast_polarity"])  # lower the better

        # Fraction of tests that agree with the majority sign (higher the better).
        side_dependability = max(np.sum(side > 0), np.sum(side < 0)) / len(side)
        contrast_dependability = max(np.sum(contrast > 0), np.sum(contrast < 0)) / len(contrast)

        # Variability of a categorical variable, following Allaj, Erindi.
        # (2017). "Two Simple Measures of Variability for Categorical Data."
        # Journal of Applied Statistics. 1-20. 10.2139/ssrn.2892097.
        # The measure is defined on RELATIVE frequencies; feeding raw counts
        # would make it negative and dependent on the number of tests.
        _, counts = np.unique(selectivity, return_counts=True)
        proportions = counts / counts.sum()
        selectivity_variability = 1 - np.sqrt(np.sum(np.power(proportions, 2)))
        major_selectivity = Counter(selectivity).most_common(1)[0][0]

        dependabilities[unit] = {"SOI_std": SOI_std,
                                 "CPI_std": CPI_std,
                                 "side_dependability": side_dependability,
                                 "contrast_dependability": contrast_dependability,
                                 "major_selectivity": major_selectivity,
                                 "selectivity_variability": selectivity_variability}
    return dependabilities


# Test
if __name__ == "__main__":
    # Toy record of a unit whose matrics barely change from test to test.
    dependable_matrics = dict(
        side_of_ownership=np.array([0.3, 0.25, 0.31, 0.34, 0.33]),
        contrast_polarity=np.array([0.8, 0.77, 0.81, 0.82, 0.79]),
        side_matric=np.array([10, 20, 15, 22, 17]),
        contrast_matric=np.array([-10, -20, -55, -66, -20]),
        selectivity=np.array([1,1,1,0,1], dtype=int),
    )
    # Toy record of a unit whose matrics (including the signs) fluctuate.
    not_dependable_matrics = dict(
        side_of_ownership=np.array([0.3, 0.7, 0.31, 0.9, 0.4]),
        contrast_polarity=np.array([0.2, 0.3, 0.5, 0.82, 0.79]),
        side_matric=np.array([10, 20, -15, -22, 17]),
        contrast_matric=np.array([-10, -20, 55, -66, -20]),
        selectivity=np.array([2,2,1,0,1], dtype=int),
    )

    # Key the two records by made-up unit indices and show both results.
    pm = {1: dependable_matrics, 2: not_dependable_matrics}
    dependabilities = get_dependability(pm)
    print(dependabilities[1], dependabilities[2], sep="\n")

"""# Plot Results"""

def plot_strength_of_selectivity(matrics, layer_name, test_name, N=25, figsize=(15,15)):
    """
    Plots a scatterplot of a single layer (result from a single test).
        x-axis : side of ownership index (SOI)
        y-axis : contrast polarity index (CPI)
    There are also histograms on the top and on the right to show the
    distributions of the matrics.

    A new figure is created; the function neither shows nor returns it, so
    call plt.show() (or plt.savefig()) afterwards.

    Parameters
    ----------
    matrics : dict (output of get_selectivity_matrics())
        key : int
            The unit's index.
        value : dict (keys are the name of the values)
            side_of_ownership : float
                Side of ownership index (SOI). The LOWER the SOI, the stronger
                the border-ownership selectivity.
            contrast_polarity : float
                Contrast polarity index (CPI). The LOWER the index, the
                stronger the contrast-polarity selectivity.
            side_matric : float
                Which side does the unit own? Side > 0 (left), Side < 0 (right).
            contrast_matric : float
                Which contrast (from left to right of the RF) does the unit
                prefer? Contrast > 0 (light to dark); Contrast < 0 (dark to light).
            selectivity : int
                Selectivity type code in 0..3; selects the marker color
                (0: black, 1: red, 2: green, 3: yellow).
        Only side_of_ownership, contrast_polarity and selectivity are read.

    layer_name : str
        The name of layer (e.g., Conv2). Used for giving a title to the plot.
    test_name : str
        The name of test (e.g., overlap). Used for giving a title to the plot.
    N : int
        The number of units to have their indices displayed on the scatter
        plot, PER AXIS: the N units with the lowest SOI and the N units with
        the lowest CPI are labelled (a unit in both groups is labelled once).
    figsize : tuple
        Figure size forwarded to plt.figure().

    Example
    -------
    lx = ly = 120
    model = models.alexnet(pretrained=False)
    pretrained_model = models.alexnet(pretrained=True)
    layer_name = 'Conv2'
    unit_idx = 58
    xn = yn = 227

    unit_stats = np.expand_dims(unit_stats_alex[layer_name][unit_idx,:], axis=0)
    ori_stats = np.expand_dims(ori_stats_alex[layer_name][unit_idx,3] * math.pi / 180, axis=0)
    layer_dict = ccnn.get_info_of_all_layers(model, pretrained_model, xn, yn)
    layer_idx = layer_dict[layer_name].layer_index

    standard_dict = bo.make_standard_sets(unit_stats, ori_stats, lx=lx, ly=ly)
    standard_activation = get_activation_maps(standard_dict, layer_idx, model=pretrained_model)
    center_responses = get_center_responses(standard_activation)

    matrics = get_selectivity_matrics(center_responses)
    plot_strength_of_selectivity(matrics, layer_name, "standard")
    plt.show()
    """
    # Layout: a 4x4 grid with the scatter plot in the lower-left 3x3 cells and
    # the marginal histograms along its top and right edges.
    fig = plt.figure(figsize=figsize)
    gs = GridSpec(4, 4)
    ax_scatter = fig.add_subplot(gs[1:4,0:3]) # center scatterplot
    ax_hist_x = fig.add_subplot(gs[0,0:3]) # histogram above
    ax_hist_y = fig.add_subplot(gs[1:4,3]) # histogram on the right

    colors = ['k', 'r', 'g', 'y']

    # Flatten the per-unit dictionaries into parallel lists.
    unit_idx = list(matrics)
    side_of_ownership = [data["side_of_ownership"] for data in matrics.values()]
    contrast_polarity = [data["contrast_polarity"] for data in matrics.values()]
    point_colors = [colors[data["selectivity"]] for data in matrics.values()]

    # One scatter call for the whole layer; each point is colored by its
    # selectivity type.
    if unit_idx:
        ax_scatter.scatter(side_of_ownership, contrast_polarity, c=point_colors)

    fig.suptitle(f"Strength of selectivity of {layer_name} with {test_name} test (n={len(side_of_ownership)})\n\
    Type0:black, Type1:red, Type2:green, Type3:yellow")

    # Plot histograms
    ax_hist_x.hist(side_of_ownership)
    ax_hist_y.hist(contrast_polarity, orientation="horizontal")
    # Add labels
    ax_scatter.set_xlabel("Side-of-ownership discrimination")
    ax_scatter.set_ylabel("Contrast polarity discrimination")

    # Positions of the N units with the lowest SOI and of the N units with the
    # lowest CPI (sorted() is stable, so ties keep dictionary order).
    positions = range(len(unit_idx))
    top_side_of_ownership_idx = sorted(positions, key=side_of_ownership.__getitem__)[:N]
    top_contrast_polarity_idx = sorted(positions, key=contrast_polarity.__getitem__)[:N]
    # Merge the two groups without labelling a unit twice.
    already_listed = set(top_side_of_ownership_idx)
    top_reliability_idx = top_side_of_ownership_idx + [
        i for i in top_contrast_polarity_idx if i not in already_listed]

    # Show the unit's index on the scatter plot.
    for i in top_reliability_idx:
        ax_scatter.annotate(f"#{unit_idx[i]}", (side_of_ownership[i], contrast_polarity[i]))


# Test
if __name__ == "__main__":
    # Stimulus size arguments forwarded to bo.make_standard_sets().
    lx = ly = 120
    # An untrained and a pretrained AlexNet; both are handed to
    # ccnn.get_info_of_all_layers() below.
    model = models.alexnet(pretrained=False)
    pretrained_model = models.alexnet(pretrained=True)
    layer_name = 'Conv2'
    unit_idx = 58
    # Size arguments handed to ccnn.get_info_of_all_layers().
    xn = yn = 227

    # unit_stats_alex / ori_stats_alex are not defined in this cell; they must
    # already exist in the session (presumably the RF statistics loaded near
    # the top of the file -- TODO confirm).
    # Keep only the row of the chosen unit, with a leading axis of length 1.
    unit_stats = np.expand_dims(unit_stats_alex[layer_name][unit_idx,:], axis=0)
    # Column 3 is scaled by pi/180 (presumably degrees -> radians -- TODO confirm).
    ori_stats = np.expand_dims(ori_stats_alex[layer_name][unit_idx,3] * math.pi / 180, axis=0)
    # Look up the index of the named layer.
    layer_dict = ccnn.get_info_of_all_layers(model, pretrained_model, xn, yn)
    layer_idx = layer_dict[layer_name].layer_index

    # Build the stimuli, present them to the pretrained model and reduce the
    # activation maps to center responses.
    standard_dict = bo.make_standard_sets(unit_stats, ori_stats, lx=lx, ly=ly)
    standard_activation = get_activation_maps(standard_dict, layer_idx, model=pretrained_model)
    center_responses = get_center_responses(standard_activation)

    # Compute the selectivity matrics and plot them; the plotting function
    # does not call plt.show() itself.
    matrics = get_selectivity_matrics(center_responses)
    plot_strength_of_selectivity(matrics, layer_name, "standard")
    plt.show()

def _annotate_top_units(unit_idx, x_values, y_values, N, highest=False):
    """Labels, on the current axes, the N most extreme units along each axis."""
    positions = range(len(unit_idx))
    # sorted() is stable, so tied units keep their original order.
    top_x = sorted(positions, key=x_values.__getitem__, reverse=highest)[:N]
    top_y = sorted(positions, key=y_values.__getitem__, reverse=highest)[:N]
    # Merge the two groups without labelling a unit twice.
    already_listed = set(top_x)
    for i in top_x + [j for j in top_y if j not in already_listed]:
        plt.annotate(f"#{unit_idx[i]}", (x_values[i], y_values[i]))


def plot_dependabilities(dependabilities, layer_name, N=25, figsize=(24,8)):
    """
    Plots 3 plots for a single layer (result from multiple tests).
        Plot #1: scatter plot
            x-axis : std of side of ownership index (SOI_std)
            y-axis : std of contrast polarity index (CPI_std)
        Plot #2: scatter plot
            x-axis : fraction of consistent side preference (side_dependability)
            y-axis : fraction of consistent contrast preference (contrast_dependability)
            *1.0 = perfectly consistent*
        Plot #3: box plot
            x-axis : major_selectivity (categorical)
            y-axis : selectivity_variability

    A new figure is created and displayed with plt.show(); nothing is returned.

    Parameters
    ----------
    dependabilities : dict (output of get_dependability())
        key : int
            The unit's index.
        value : dict (keys are the name of the values)
            SOI_std : float
                The standard deviation of the side-of-ownership index (SOI).
                The LOWER the SOI_std, the MORE DEPENDABLE a unit is in terms of
                border-ownership selectivity.
            CPI_std : float
                The standard deviation of the contrast-polarity index (CPI).
                The LOWER the CPI_std, the MORE DEPENDABLE a unit is in terms of
                contrast-polarity selectivity.
            side_dependability : float
                The fraction of consistent side-of-ownership preference.
                1.0 = perfectly consistent.
            contrast_dependability : float
                The fraction of consistent contrast-polarity preference.
                1.0 = perfectly consistent.
            major_selectivity : int
                The most frequent selectivity type code, in 0..3. It selects
                the marker color (0: black, 1: red, 2: green, 3: yellow) and
                the box of plot #3.
            selectivity_variability : float
                How variable is the selectivity type (a categorical variable).
    layer_name : str
        The name of layer (e.g., Conv2). Used for giving a title to the plot.
    N : int
        The number of units to have their indices displayed on each scatter
        plot, PER AXIS (a unit that is extreme on both axes is labelled once).
        Plot #1 labels the LOWEST values, plot #2 the HIGHEST.
    figsize : tuple
        Figure size forwarded to plt.figure().

    Example
    -------
    dependable_matrics = {"side_of_ownership": np.array([0.3, 0.25, 0.31, 0.34, 0.33]),
                        "contrast_polarity": np.array([0.8, 0.77, 0.81, 0.82, 0.79]),
                        "side_matric": np.array([10, 20, 15, 22, 17]),
                        "contrast_matric": np.array([-10, -20, -55, -66, -20]),
                        "selectivity": np.array([1,1,1,0,1], dtype=int)}
    not_dependable_matrics = {"side_of_ownership": np.array([0.3, 0.7, 0.31, 0.9, 0.4]),
                        "contrast_polarity": np.array([0.2, 0.3, 0.5, 0.82, 0.79]),
                        "side_matric": np.array([10, 20, -15, -22, 17]),
                        "contrast_matric": np.array([-10, -20, 55, -66, -20]),
                        "selectivity": np.array([2,2,1,0,1], dtype=int)}
    pm = {1:dependable_matrics, 2:not_dependable_matrics}
    dependabilities = get_dependability(pm)
    plot_dependabilities(dependabilities, layer_name, N=25, figsize=(24,8))
    """
    # Initializations
    plt.figure(figsize=figsize)
    num_units = len(dependabilities)
    colors = ['k', 'r', 'g', 'y']
    # Unit keys stay in a plain list so that labels show the key exactly as
    # given (a float array would render unit 58 as "#58.0").
    unit_idx = list(dependabilities)
    SOI_std = np.zeros((num_units,))
    CPI_std = np.zeros((num_units,))
    side_dependability = np.zeros((num_units,))
    contrast_dependability = np.zeros((num_units,))
    major_selectivity = np.zeros((num_units,), dtype=int)
    selectivity_variability = np.zeros((num_units,))

    # Extract matrics from the dictionary
    for i, dependability in enumerate(dependabilities.values()):
        SOI_std[i] = dependability["SOI_std"]
        CPI_std[i] = dependability["CPI_std"]
        side_dependability[i] = dependability["side_dependability"]
        contrast_dependability[i] = dependability["contrast_dependability"]
        major_selectivity[i] = dependability["major_selectivity"]
        selectivity_variability[i] = dependability["selectivity_variability"]

    # Color of each data point, chosen by the most common selectivity type.
    point_colors = [colors[selectivity_type] for selectivity_type in major_selectivity]

    plt.suptitle(f"Dependability meausures in {layer_name} (n={len(unit_idx)})\n\
    Type0:black, Type1:red, Type2:green, Type3:yellow")

    # Plot #1: scatter plot
    #         x-axis : std of side of ownership index (SOI_std)
    #         y-axis : std of contrast polarity index (CPI_std)
    plt.subplot(1,3,1)
    if num_units:
        plt.scatter(SOI_std, CPI_std, c=point_colors)
    plt.xlabel("SOI_std")
    plt.ylabel("CPI_std")
    plt.title("Standard deviations of SOI and CPI")
    # Label the units with the LOWEST standard deviations (most dependable).
    _annotate_top_units(unit_idx, SOI_std, CPI_std, N)

    # Plot #2: scatter plot
    #         x-axis : fraction of consistent side preference (side_dependability)
    #         y-axis : fraction of consistent contrast preference (contrast_dependability)
    plt.subplot(1,3,2)
    if num_units:
        plt.scatter(side_dependability, contrast_dependability, c=point_colors)
    plt.xlabel("side_dependability")
    plt.ylabel("contrast_dependability")
    plt.title("Fraction of consistent side/contrast preferences")
    # Label the units with the HIGHEST fractions (most dependable).
    _annotate_top_units(unit_idx, side_dependability, contrast_dependability, N, highest=True)

    # Plot #3: box plot
    #         x-axis : major_selectivity (categorical)
    #         y-axis : selectivity_variability
    plt.subplot(1,3,3)
    type0, type1, type2, type3 = (
        selectivity_variability[major_selectivity == selectivity_type]
        for selectivity_type in range(4))
    plt.boxplot([type0, type1, type2, type3])
    plt.title(f"There are {len(type0)} type0, {len(type1)} type1, {len(type2)} type2, and {len(type3)} type3.")
    plt.xlabel('selectivity type')
    plt.ylabel('selectivity_variability')
    # boxplot() numbers its boxes from 1; relabel them with the type codes.
    plt.xticks([1, 2, 3, 4], [0, 1, 2, 3])

    plt.show()


# Test
if __name__ == "__main__":
    # Toy record of a unit whose matrics barely change from test to test.
    dependable_matrics = {"side_of_ownership": np.array([0.3, 0.25, 0.31, 0.34, 0.33]),
                        "contrast_polarity": np.array([0.8, 0.77, 0.81, 0.82, 0.79]),
                        "side_matric": np.array([10, 20, 15, 22, 17]),
                        "contrast_matric": np.array([-10, -20, -55, -66, -20]),
                        "selectivity": np.array([1,1,1,0,1], dtype=int)}
    # Toy record of a unit whose matrics (including the signs) fluctuate.
    not_dependable_matrics = {"side_of_ownership": np.array([0.3, 0.7, 0.31, 0.9, 0.4]),
                        "contrast_polarity": np.array([0.2, 0.3, 0.5, 0.82, 0.79]),
                        "side_matric": np.array([10, 20, -15, -22, 17]),
                        "contrast_matric": np.array([-10, -20, 55, -66, -20]),
                        "selectivity": np.array([2,2,1,0,1], dtype=int)}
    # Key the two records by made-up unit indices.
    pm = {1:dependable_matrics, 2:not_dependable_matrics}
    dependabilities = get_dependability(pm)
    # `layer_name` is not assigned in this cell; it must already exist in the
    # session (the plot_strength_of_selectivity test cell sets it to 'Conv2').
    plot_dependabilities(dependabilities, layer_name, N=25, figsize=(24,8))

def visualize_activation(img_tensor, pretrained_model, layer_idx, unit_indices):
    """
    Plots the activations of the specified units.

    The model is switched to evaluation mode, the image is passed through
    `pretrained_model.features[:layer_idx+2]`, and one gray-scale map per
    requested unit is drawn in a new figure (8 maps per row, color range fixed
    to [-3, 3]). The figure is not shown; call plt.show() afterwards.

    Parameters
    ----------
    img_tensor: torch.Tensor
        The input tensor.
    pretrained_model: torch.models
        The pretrained PyTorch model. It must expose a sliceable `features`
        attribute (an nn.Sequential object).
    layer_idx: int
        The layer index.
    unit_indices: list of int
        The indices of the units.

    Returns
    -------
    activations_dict: dict
        key: int
            The unit index.
        value: torch.Tensor
            `activation[:, unit, :, :]`, i.e. the map of that unit with the
            leading (batch) axis kept.

    Example
    -------
    pretrained_model = models.alexnet(pretrained=True)
    img0 = bo.make_standard_set(75,125,15).astype(np.float32)

    transform = transforms.Compose([transforms.Resize([227,227]),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                        std =[0.229, 0.224, 0.225])])
    plt.figure()
    plt.imshow(img0[0].transpose(1,2,0))
    img0_tensor = torch.unsqueeze(torch.tensor(img0[0]), dim=0)
    activations_dict2 = visualize_activation(img0_tensor, pretrained_model, 3, [132])
    plt.show()
    """
    # Present the stimuli to the layers of interest
    pretrained_model.eval()
    # NOTE(review): the slice ends at layer_idx+2, so the module that FOLLOWS
    # features[layer_idx] is applied as well -- confirm this is intended for
    # every layer type this function is called with.
    features = pretrained_model.features[:layer_idx+2] # 'features' = an nn.Sequential object
    # Only the values are needed, so skip building the autograd graph.
    with torch.no_grad():
        activation = features(img_tensor)

    activations_dict = {}
    num_rows = math.ceil(len(unit_indices) / 8)

    plt.figure(figsize=(12,14))
    for i, unit in enumerate(unit_indices):
        plt.subplot(num_rows, 8, i+1)
        act = activation[:,unit,:,:]
        # Move to host memory for plotting only; the returned tensor stays on
        # the device the model produced it on.
        plt.imshow(np.squeeze(act.cpu().numpy()), cmap='gray', vmax=3, vmin=-3)
        plt.title(f"Unit #{unit}")
        activations_dict[unit] = act
        plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)

    return activations_dict


# Test
if __name__ == "__main__":
    pretrained_model = models.alexnet(pretrained=True)
    # Generate the stimuli and cast them to float32 so they can be fed to the
    # model after conversion to a tensor.
    img0 = bo.make_standard_set(75,125,15).astype(np.float32)

    # NOTE(review): `transform` is built here but never applied in this cell.
    transform = transforms.Compose([transforms.Resize([227,227]),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std =[0.229, 0.224, 0.225])])
    # Show the first stimulus; the transpose moves axis 0 to the end, which is
    # where imshow() expects the color channels.
    plt.figure()
    plt.imshow(img0[0].transpose(1,2,0))
    # Add a leading axis of length 1 and plot the map of unit 132 for
    # layer_idx=3.
    img0_tensor = torch.unsqueeze(torch.tensor(img0[0]), dim=0)
    activations_dict2 = visualize_activation(img0_tensor, pretrained_model, 3, [132])
    plt.show()

def visualize_kernel(pretrained_model, layer_idx, unit_idx):
    """
    Plots all the channels of a single kernel (a.k.a. unit) of the layer.

    The weight tensor of `pretrained_model.features[layer_idx]` is read, its
    shape is printed, and every input channel of kernel `unit_idx` is drawn in
    a new figure (8 channels per row, color range fixed to [-0.1, 0.1]). A
    kernel with exactly 3 input channels is drawn with red, green and blue
    colormaps; any other kernel is drawn in gray. The figure is not shown;
    call plt.show() afterwards.

    Parameters
    ----------
    pretrained_model: torch.models
        The pretrained PyTorch model.
    layer_idx: int
        The layer index. `features[layer_idx]` must have a `weight` attribute.
    unit_idx: int
        The unit index.

    Example
    -------
    pretrained_model = models.alexnet(pretrained=True)
    visualize_kernel(pretrained_model, 3, 58)
    """
    weights = pretrained_model.features[layer_idx].weight.detach().cpu()
    print(weights.shape)
    # Axis 1 of the weight tensor indexes the input channels of each kernel.
    num_channels = weights.shape[1]
    num_rows = math.ceil(num_channels / 8)
    is_rgb = num_channels == 3
    rgb_cmaps = ['Reds', 'Greens', 'Blues']

    plt.figure(figsize=(12,16))
    for i in range(num_channels):
        plt.subplot(num_rows, 8, i+1)
        cmap = rgb_cmaps[i] if is_rgb else 'gray'
        plt.imshow(weights[unit_idx,i,:,:].numpy(), cmap=cmap, vmax=0.1, vmin=-0.1)

        plt.title(f"Unit {unit_idx} Ch.{i}")
        plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)


# Test
if __name__ == "__main__":
    # Plot every channel of kernel 58 of features[3] of a pretrained AlexNet.
    pretrained_model = models.alexnet(pretrained=True)
    visualize_kernel(pretrained_model, 3, 58)

def visualize_kernels(pretrained_model, layer_idx, achromatic=True):
    """
    Plots ALL the kernels/units of a layer.

    Every kernel of `pretrained_model.features[layer_idx]` gets its own
    subplot in a new figure (8 subplots per row). The figure is not shown;
    call plt.show() afterwards.

    Parameters
    ----------
    pretrained_model: torch.models
        The pretrained PyTorch model.
    layer_idx: int
        The layer index. `features[layer_idx]` must have a `weight` attribute.
    achromatic: bool
        True: plot only channel 0 of each unit, in gray scale.
        False: each unit has 3 channels (RGB) and they are plotted in colors.

    Example
    -------
    pretrained_model = models.alexnet(pretrained=True)
    visualize_kernels(pretrained_model, 0, achromatic=False)
    """
    kernels = pretrained_model.features[layer_idx].weight.detach().cpu()
    # NOTE(review): ceil(n/4)+1 rows of 8 columns reserves about twice the
    # rows the kernels need -- confirm the sparse layout is intentional.
    grid_rows = math.ceil(kernels.shape[0] / 4) + 1

    plt.figure(figsize=(12,30))
    for i, kernel in enumerate(kernels):
        plt.subplot(grid_rows, 8, i + 1)

        # The scaled weights are clipped to [0, 225] before the uint8 cast so
        # that out-of-range values do not wrap around.
        # NOTE(review): 225 (not 255, the uint8 maximum) is used for both the
        # scale and the clip -- confirm whether 255 was meant.
        if achromatic:
            scaled = (225 * kernel[0]).numpy()
            plt.imshow(np.clip(scaled, 0, 225).astype(np.uint8), cmap='gray')
        else:
            # 3 color channels, moved to the last axis for imshow()
            scaled = (500 * kernel + 100).permute(1, 2, 0).numpy()
            plt.imshow(np.clip(scaled, 0, 225).astype(np.uint8))

        plt.title(f"Unit #{i}")
        plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)


# Test
if __name__ == "__main__":
    # Plot all kernels of features[0] of a pretrained AlexNet in color.
    pretrained_model = models.alexnet(pretrained=True)
    visualize_kernels(pretrained_model, 0, achromatic=False)

def visualize_single_convolution(activations_dict, pretrained_model, layer_idx, unit_idx):
    """
    Plots and returns the 2D convolution result of a single unit.

    To get the 2D convolution result of a layer, we can do:

        features = model.features[:layer_idx]
        activation = features.forward(stimulus_set).detach().numpy()

    However, the activation here already sums up the contributions of all the
    input channels. If we want to see the contribution of each individual
    input channel to a unit in this layer, we have to do the convolution
    ourselves. This is what this function is for--to see what a single unit
    does to each activation map from its previous layer.

    Each map is convolved with the matching channel of the unit's kernel,
    using the stride and padding of `pretrained_model.features[layer_idx]`.
    No other setting of that layer is applied; in particular the bias is not
    added. The results are drawn in a new figure (8 maps per row, color range
    fixed to [-2, 2]); the figure is not shown, call plt.show() afterwards.

    Parameters
    ----------
    activations_dict: dict
        The output of visualize_activation.
        key: int
            The unit index in the previous layer. It is also used as the
            input-channel index into the kernel of `unit_idx`.
        value: torch.Tensor
            The activation map of the previous layer.
    pretrained_model: torch.models
        The pretrained PyTorch model.
    layer_idx: int
        The index of the layer that contains the unit of interest.
    unit_idx: int
        The index of the unit (kernel) of interest in that layer.

    Returns
    -------
    convolution_dict: dict
        key: int
            The unit index (same keys as activations_dict).
        value: numpy.array
            The 2D result of convolving that unit's map with the matching
            kernel channel.

    Example
    -------
    layer_idx = 3   # 3 is for Conv2
    unit_idx = 132
    pretrained_model = models.alexnet(pretrained=True)

    img0 = bo.draw_rectangle(125, 125, lx=35, ly=35, angle=0.01).astype(np.float32)
    bo.plot_one_stimulus(img0)
    img0_tensor = torch.tensor(img0)

    # get activation of MaxPool1
    activations_dict1 = visualize_activation(torch.unsqueeze(img0_tensor,dim=0), pretrained_model, 2, [i for i in range(64)])
    plt.show()

    # get the result of convolution of unit 132
    convolution_dict = visualize_single_convolution(activations_dict1, pretrained_model, layer_idx, unit_idx)
    plt.show()
    """
    layer = pretrained_model.features[layer_idx]
    stride = layer.stride
    padding = layer.padding

    convolution_dict = {}
    # Keep the 8x8 grid for up to 64 maps and add rows beyond that, so larger
    # dictionaries no longer exceed the number of subplot slots.
    num_cols = 8
    num_rows = max(8, math.ceil(len(activations_dict) / num_cols))

    plt.figure(figsize=(12,14))
    # Only the values are needed, so skip building the autograd graph.
    with torch.no_grad():
        for plot_pos, (unit, act) in enumerate(activations_dict.items(), start=1):
            plt.subplot(num_rows, num_cols, plot_pos)
            # conv2d expects 4 axes for both operands: add one leading axis
            # to the map and two leading axes to the 2D kernel channel.
            act = torch.unsqueeze(act, dim=0)
            w = torch.unsqueeze(layer.weight[unit_idx,unit,:,:], dim=0)
            w = torch.unsqueeze(w, dim=1)
            a = F.conv2d(act, w, stride=stride, padding=padding)

            conv_result = a[0,0,:,:].cpu().numpy()
            convolution_dict[unit] = conv_result
            plt.imshow(conv_result, cmap='gray', vmax=2, vmin=-2)
            plt.title(f"conv of {unit}")
            plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    return convolution_dict


# Test
if __name__ == "__main__":
    layer_idx = 3   # 3 is for Conv2
    unit_idx = 132
    pretrained_model = models.alexnet(pretrained=True)

    # Draw one rectangle stimulus, cast to float32 so it can be fed to the
    # model, and display it.
    img0 = bo.draw_rectangle(125, 125, lx=35, ly=35, angle=0.01).astype(np.float32)
    bo.plot_one_stimulus(img0)
    img0_tensor = torch.tensor(img0)

    # Collect the maps of units 0..63 for layer index 2.
    # NOTE(review): visualize_activation() slices features[:layer_idx+2], so
    # layer_idx=2 runs features[0] through features[3]; confirm that this is
    # the MaxPool1 output the message below announces.
    print('Activations from MaxPool1')
    activations_dict1 = visualize_activation(torch.unsqueeze(img0_tensor,dim=0), pretrained_model, 2, [i for i in range(64)])
    plt.show()

    # Convolve each collected map with the matching kernel channel of the
    # unit of interest.
    print(f"unit {unit_idx} convolution result of the activation of a previous layer")
    convolution_dict = visualize_single_convolution(activations_dict1, pretrained_model, layer_idx, unit_idx)
    plt.show()

"""# Other Techniques
* Guided Backprop
"""

class GuidedBackprop():
    """
       Produces gradients generated with guided back propagation from the given image
       ### Original code source: https://github.com/utkuozbulak/pytorch-cnn-visualizations.git

       On construction the model is put in evaluation mode and hooks are
       registered on the direct children of `pretrained_model.features`:
       a forward and a backward hook on every nn.ReLU, and a backward hook on
       the first child, which captures the gradient with respect to the image.

       The hooks stay attached to the model until remove_hooks() is called, so
       create only one active GuidedBackprop per model at a time.
    """
    def __init__(self, pretrained_model):
        self.pretrained_model = pretrained_model
        self.gradients = None
        self.outputs = []
        self.forward_relu_outputs = []
        self._hook_handles = []
        # Put pretrained_model in evaluation mode
        self.pretrained_model.eval()
        self.update_relus()
        self.hook_layers()

    def hook_layers(self):
        """
            Registers a backward hook on the first layer of `features` that
            stores the gradient with respect to that layer's input.
        """
        def hook_function(module, grad_in, grad_out):
            self.gradients = grad_in[0]
        # Register hook to the first layer
        first_layer = next(iter(self.pretrained_model.features.children()))
        # NOTE(review): the legacy register_backward_hook() is kept on purpose;
        # the PyTorch docs state that the "full" backward hooks do not allow
        # in-place modification of inputs/outputs (e.g. nn.ReLU(inplace=True)).
        self._hook_handles.append(first_layer.register_backward_hook(hook_function))

    def update_relus(self):
        """
            Updates relu activation functions so that
                1- stores output in forward pass
                2- imputes zero for gradient values that are less than zero
        """
        def relu_backward_hook_function(module, grad_in, grad_out):
            """
            If there is a negative gradient, change it to zero
            """
            # Outputs were stored in forward order, so the last one belongs
            # to the ReLU whose backward pass is running now.
            corresponding_forward_output = self.forward_relu_outputs.pop()
            # 1 where the ReLU was active, 0 elsewhere. Built as a new tensor
            # so that the stored activation itself is left untouched.
            active_mask = (corresponding_forward_output > 0).to(grad_in[0].dtype)
            modified_grad_out = active_mask * torch.clamp(grad_in[0], min=0.0)
            return (modified_grad_out,)

        def relu_forward_hook_function(module, ten_in, ten_out):
            """
            Store results of forward pass
            """
            self.forward_relu_outputs.append(ten_out)

        # Loop through layers, hook up ReLUs
        for module in self.pretrained_model.features.children():
            if isinstance(module, nn.ReLU):
                self._hook_handles.append(
                    module.register_backward_hook(relu_backward_hook_function))
                self._hook_handles.append(
                    module.register_forward_hook(relu_forward_hook_function))

    def remove_hooks(self):
        """
            Detaches every hook this object registered on the model.
        """
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []
        self.forward_relu_outputs = []

    def generate_gradients(self, input_image, layer_index, unit_index, center_pos):
        """
            Computes the guided-backprop gradient of one unit with respect to
            the input image.

            Parameters
            ----------
            input_image : numpy.array
                The image batch handed to the model via torch.from_numpy().
                Its dtype must match the dtype of the model parameters.
            layer_index : int
                The forward pass runs `features[:layer_index+1]`.
            unit_index : int
                The channel of the target unit in the output of that slice.
            center_pos : tuple of int
                Position of the target element, indexed as
                outputs[0, unit_index, center_pos[1], center_pos[0]].

            Returns
            -------
            numpy.array
                The gradient captured at the first layer, for the first image
                of the batch (the leading axis is dropped).
        """
        self.pretrained_model.zero_grad()
        # Drop ReLU outputs left behind by forward passes that were run on the
        # model outside of this method and never back-propagated.
        self.forward_relu_outputs.clear()

        # Forward pass
        features = self.pretrained_model.features[:layer_index+1] # 'features' = an nn.Sequential object
        # The image must require grad so that the first layer's backward hook
        # receives a gradient for it.
        x = torch.from_numpy(input_image).requires_grad_(True)
        self.outputs = features(x)

        # Target for backprop: zero everywhere except at the chosen element,
        # where the seed equals the response itself. zeros_like() makes the
        # seed match the dtype and device of the model output.
        one_hot_output = torch.zeros_like(self.outputs)
        one_hot_output[0,unit_index,center_pos[1],center_pos[0]] = \
            self.outputs[0,unit_index,center_pos[1],center_pos[0]].detach()

        # Backward pass
        self.outputs.backward(gradient=one_hot_output)

        # Convert Pytorch tensor to numpy array
        # [0] to get rid of the first channel (1,3,224,224)
        return self.gradients.detach().cpu().numpy()[0]

def norm_im(im, min_range=0.1):
    """
    Normalizes image to the range [0, 1].

    The minimum is subtracted first, then the result is divided by its
    maximum. The input is not modified.

    Parameters
    ----------
    im : numpy.array
        The image to normalize.
    min_range : float
        If the value range (max - min) of the image is below this threshold,
        an all-zero image is returned instead, so that tiny fluctuations
        (e.g. rounding error) are not stretched to full contrast.

    Returns
    -------
    numpy.array
        The normalized image, same shape as `im`.

    Example
    -------
    norm_im(np.array([1.0, 2.0, 3.0]))    # -> array([0. , 0.5, 1. ])
    """
    shifted = im - im.min()
    # After the shift the minimum is 0, so the maximum equals the value range.
    value_range = shifted.max()
    if value_range < min_range:
        return shifted * 0
    return shifted / value_range


# Test
if __name__ == "__main__": 
    # Draw one rectangle stimulus and add a leading axis of length 1.
    # NOTE(review): the name `input` shadows the Python builtin of the same
    # name for the rest of the session.
    input = bo.draw_rectangle(111, 111, 100, 20, angle=0.3)
    input = np.expand_dims(input, axis=0)

    # Convert the model parameters to float64 (presumably to match the dtype
    # of the stimulus array -- TODO confirm).
    pretrained_model = models.alexnet(pretrained=True)
    pretrained_model.double()

    # Guided backprop for unit 15 in the output of features[:4], seeded at
    # center_pos (12,12).
    GBP = GuidedBackprop(pretrained_model)
    guided_grads = GBP.generate_gradients(input, 3, 15, (12,12))

    # Left: the stimulus. Right: the normalized guided gradients. Both are
    # transposed so that axis 0 becomes the last axis for imshow().
    plt.figure()
    plt.subplot(1,2,1)
    plt.imshow(np.transpose(np.squeeze(input), (1, 2, 0))) 

    plt.subplot(1,2,2)
    im = norm_im(guided_grads)
    plt.imshow(np.transpose(im, (1, 2, 0))) 
    plt.show()