borderownership / src / rf_mapping / rfmpz / rf_mapping_live(standalone).ipynb
rf_mapping_live(standalone).ipynb
Raw
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RF mapping animation: a stand-alone notebook\n",
    "\n",
    "Tony Fu, July 2022\n",
    "\n",
    "Used to map the RFs of the earlier layers of neural networks.\n",
    "\n",
    "Note: I simplified many functions here. Please see my GitHub repository (https://github.com/tonyfu97/borderownership) for the full implementations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import math\n",
    "\n",
    "import numpy as np\n",
    "from numba import njit\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import models\n",
    "# from torchvision.models import AlexNet_Weights, VGG16_Weights\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as patches\n",
    "import matplotlib.animation as animation\n",
    "\n",
    "%matplotlib widget"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specify the unit (i.e., convolutional kernel) of interest:\n",
    "\n",
    "**NOTE**: conv_i = 0 means conv1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    # torchvision >= 0.13 replaces the deprecated `pretrained` flag with `weights`.\n",
    "    from torchvision.models import AlexNet_Weights\n",
    "    model = models.alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)\n",
    "except ImportError:  # older torchvision: only the legacy flag exists\n",
    "    model = models.alexnet(pretrained=True)\n",
    "model_name = 'alexnet'\n",
    "# model = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)\n",
    "# model_name = 'vgg16'\n",
    "\n",
    "conv_i = 0  # 0-based: conv_i = 0 selects conv1\n",
    "unit_i = 0  # which unit (convolutional kernel) of that layer to map\n",
    "animation_interval_ms = 10  # If too short, the animation may not plot some of the bars (but will still input all the bars to the model)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specify the mapping parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar length, as a fraction of the layer's maximum RF size.\n",
    "blen_rf_ratios = np.array([48/64, 24/64, 12/64, 6/64])\n",
    "# Bar width, as a fraction of bar length.\n",
    "aspect_ratios = np.array([1/2, 1/5, 1/10])\n",
    "# Bar orientations in degrees (0, 22.5, ..., 157.5).\n",
    "orientations = np.arange(0, 180, 22.5)\n",
    "# rgb bar (fg) / background (bg) values, paired by position: fgvals[k] is\n",
    "# drawn on bgvals[k], so both lists must have the same length.\n",
    "fgvals = [(-1, -1, -1), ( 1,  1,  1),\n",
    "          ( 1, -1, -1), ( 0,  0,  0),\n",
    "          (-1,  1, -1), ( 0,  0,  0),\n",
    "          (-1, -1,  1), ( 0,  0,  0)]\n",
    "bgvals = [( 1,  1,  1), (-1, -1, -1),\n",
    "          ( 0,  0,  0), ( 1, -1, -1),\n",
    "          ( 0,  0,  0), (-1,  1, -1),\n",
    "          ( 0,  0,  0), (-1, -1,  1)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helper functions and classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Bar drawing functions (written by Dr. Bair, slightly modified by me)\n",
    "\n",
    "Note1: njit is a just-in-time compiler that is recommended for functions\n",
    "       that are called very often. The stimfr_bar() function is sped up\n",
    "       by about 180x thanks to it. See Numba documentation for details.\n",
    "\n",
    "Note2: The y-axis points DOWN.\n",
    "\"\"\"\n",
    "\n",
    "@njit\n",
    "def clip(val, vmin, vmax):\n",
    "    \"\"\"\n",
    "    Return val limited to the closed interval [vmin, vmax].\n",
    "\n",
    "    Raises an Exception if the interval is empty (vmin > vmax).\n",
    "    \"\"\"\n",
    "    if vmin > vmax:\n",
    "        raise Exception(\"vmin should be smaller than vmax.\")\n",
    "    return max(min(val, vmax), vmin)\n",
    "\n",
    "@njit\n",
    "def rotate(dx, dy, theta_deg):\n",
    "    \"\"\"\n",
    "    Applies the linear map used to orient bars:\n",
    "        [dx, dy] * [[a, b], [c, d]]  = [dx_r, dy_r]\n",
    "    with a = cos, b = sin, c = sin, d = -cos of theta_deg (degrees).\n",
    "\n",
    "    Because d = -cos (the y-axis points DOWN), this is a rotation by\n",
    "    theta_deg applied after flipping the sign of dy; its determinant is -1.\n",
    "    The map is therefore its own inverse: applying it again with the SAME\n",
    "    theta_deg restores (dx, dy). Applying it with -theta_deg does not undo\n",
    "    it in general, e.g. rotate(1, 0, 90) gives approximately (0, 1), but\n",
    "    rotate(0, 1, -90) gives approximately (-1, 0).\n",
    "\n",
    "    Returns (dx_r, dy_r). Modified from Dr. Wyeth Bair's d06_mrf.py\n",
    "    \"\"\"\n",
    "    theta_rad = theta_deg * math.pi / 180.0\n",
    "    cos_t = math.cos(theta_rad)\n",
    "    sin_t = math.sin(theta_rad)\n",
    "    rot_a, rot_b = cos_t, sin_t\n",
    "    # Usually the negative sign sits in rot_c instead of rot_d, but the\n",
    "    # y-axis points downward in our case.\n",
    "    rot_c, rot_d = sin_t, -cos_t\n",
    "    dx_r = rot_a*dx + rot_c*dy\n",
    "    dy_r = rot_b*dx + rot_d*dy\n",
    "    return dx_r, dy_r\n",
    "\n",
    "@njit\n",
    "def stimfr_bar(xn, yn, x0, y0, theta, blen, bwid, aa, fgval, bgval):\n",
    "    \"\"\"\n",
    "    Render a single anti-aliased bar on a uniform background.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    xn    - (int) horizontal width of returned array\n",
    "    yn    - (int) vertical height of returned array\n",
    "    x0    - (float) horizontal offset of bar center from image center (pix)\n",
    "    y0    - (float) vertical offset of bar center from image center (pix)\n",
    "    theta - (float) orientation (deg)\n",
    "    blen  - (float) length of bar (pix)\n",
    "    bwid  - (float) width of bar (pix)\n",
    "    aa    - (float) length scale for anti-aliasing (pix); <= 0 gives hard edges\n",
    "    fgval - (float) bar value\n",
    "    bgval - (float) background value\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A float32 array of shape (yn, xn) holding bgval, plus a fraction f\n",
    "    (0..1) of (fgval - bgval) at each pixel covered by the bar.\n",
    "    \"\"\"\n",
    "    img = np.full((yn, xn), bgval, dtype='float32')  # Fill w/ BG value\n",
    "    contrast = fgval - bgval  # Luminance difference\n",
    "\n",
    "    half_wid = bwid/2.0\n",
    "    half_len = blen/2.0\n",
    "\n",
    "    # Two adjacent corners after rotation; the other two corners are their\n",
    "    # negatives, so these fix the bar's extent. Pad by aa for the soft edge.\n",
    "    cx1, cy1 = rotate(-half_wid, half_len, theta)\n",
    "    cx2, cy2 = rotate(half_wid, half_len, theta)\n",
    "    reach_x = aa + max(abs(cx1), abs(cx2))\n",
    "    reach_y = aa + max(abs(cy1), abs(cy2))\n",
    "\n",
    "    # Center of stimulus field\n",
    "    xc = (xn-1.0)/2.0\n",
    "    yc = (yn-1.0)/2.0\n",
    "\n",
    "    # Only visit the pixels of the bar's bounding box, clipped to the image.\n",
    "    col_lo = clip(round(xc + x0 - reach_x), 0, xn)\n",
    "    col_hi = clip(round(xc + x0 + reach_x) + 1, 0, xn)\n",
    "    row_lo = clip(round(yc + y0 - reach_y), 0, yn)\n",
    "    row_hi = clip(round(yc + y0 + reach_y) + 1, 0, yn)\n",
    "\n",
    "    for i in range(col_lo, col_hi):\n",
    "        rel_x = i - xc - x0  # relative to bar center\n",
    "        for j in range(row_lo, row_hi):\n",
    "            rel_y = j - yc - y0\n",
    "            x, y = rotate(rel_x, rel_y, -theta)  # into the bar's own frame\n",
    "\n",
    "            # Signed distance to the nearer edge of each pair: + inside, - outside.\n",
    "            dbx = bwid/2 - abs(x)\n",
    "            dby = blen/2 - abs(y)\n",
    "\n",
    "            # Distance from the bar outline, 'db': Euclidean past a corner,\n",
    "            # otherwise the smaller (the violated or nearer) of the two.\n",
    "            if dbx < 0.0 and dby < 0.0:\n",
    "                db = -math.sqrt(dbx*dbx + dby*dby)\n",
    "            else:\n",
    "                db = min(dbx, dby)\n",
    "\n",
    "            if aa > 0.0:\n",
    "                if db > aa:\n",
    "                    f = 1.0  # fully inside\n",
    "                elif db < -aa:\n",
    "                    f = 0.0  # fully outside\n",
    "                else:\n",
    "                    # Sinusoidal ramp across the edge band [-aa, aa].\n",
    "                    # NOTE(review): with 0.25*pi the ramp only spans about\n",
    "                    # 0.15..0.85 and then jumps to 0 or 1 at |db| = aa; a\n",
    "                    # continuous ramp would use 0.5*pi. Kept as-is so the\n",
    "                    # stimuli match earlier runs -- confirm intent.\n",
    "                    f = 0.5 + 0.5*math.sin(db/aa * 0.25*math.pi)\n",
    "            elif db >= 0.0:\n",
    "                f = 1.0  # inside\n",
    "            else:\n",
    "                f = 0.0  # outside\n",
    "\n",
    "            img[j, i] += f * contrast  # add a fraction 'f' of the contrast\n",
    "\n",
    "    return img\n",
    "\n",
    "\n",
    "#######################################.#######################################\n",
    "#                                                                             #\n",
    "#                               STIMFR_BAR_COLOR                              #\n",
    "#                                                                             #\n",
    "###############################################################################\n",
    "# Not njit-compiled: njit slowed this wrapper down by 4x.\n",
    "def stimfr_bar_color(xn,yn,x0,y0,theta,blen,bwid,aa,r1,g1,b1,r0,g0,b0):\n",
    "    \"\"\"\n",
    "    Parameters\n",
    "    ----------\n",
    "    xn, yn   - (int) width and height of the returned image\n",
    "    x0, y0   - (float) offset of the bar center from the image center (pix)\n",
    "    theta    - (float) orientation (deg)\n",
    "    blen     - (float) length of bar (pix)\n",
    "    bwid     - (float) width of bar (pix)\n",
    "    aa       - (float) length scale for anti-aliasing (pix)\n",
    "    r1,g1,b1 - bar color\n",
    "    r0,g0,b0 - background color\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A float32 numpy array (3, yn, xn); channel k is stimfr_bar() drawn with\n",
    "    the k-th bar/background component (k = 0, 1, 2 for r, g, b).\n",
    "    \"\"\"\n",
    "    channels = ((r1, r0), (g1, g0), (b1, b0))\n",
    "    s = np.empty((3, yn, xn), dtype='float32')\n",
    "    for k, (fg, bg) in enumerate(channels):\n",
    "        s[k] = stimfr_bar(xn, yn, x0, y0, theta, blen, bwid, aa, fg, bg)\n",
    "    return s\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Since each set of bar stimuli are designed for a single layer, it will be a\n",
    "waste of time to present them to later layers that we are not going to analyze.\n",
    "Therefore, here is a function that truncates the model and returns the output\n",
    "tensor of the specified layer.\n",
    "\"\"\"\n",
    "\n",
    "def truncated_model(x, model, layer_index):\n",
    "    \"\"\"\n",
    "    Returns the output of the specified layer without forward passing to\n",
    "    the subsequent layers.\n",
    "\n",
    "    Leaf layers are counted and applied in model.children() order, recursing\n",
    "    into containers. Only the children are called: any extra operation a\n",
    "    container performs in its own forward() (e.g. a flatten) is not applied.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : torch.tensor\n",
    "        The input. Should have dimension (1, 3, 2xx, 2xx).\n",
    "    model : torch.nn.Module\n",
    "        The neural network (or the layer if in a recursive case).\n",
    "    layer_index : int\n",
    "        The index of the layer, the output of which will be returned. The\n",
    "        indexing excludes container layers.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    y : torch.tensor\n",
    "        The output of layer with the layer_index. If layer_index exceeds the\n",
    "        number of leaf layers, this is the output of the last leaf layer.\n",
    "    layer_index : int\n",
    "        Used for recursive cases. Should be ignored. (It becomes negative\n",
    "        once the requested layer has been reached.)\n",
    "    \"\"\"\n",
    "    # A leaf layer: apply it and count it off.\n",
    "    if len(list(model.children())) == 0:\n",
    "        return model(x), layer_index - 1\n",
    "\n",
    "    # A container: feed x through its children until the target is reached.\n",
    "    for sublayer in model.children():\n",
    "        x, layer_index = truncated_model(x, sublayer, layer_index)\n",
    "        if layer_index < 0:  # Stop at the specified layer.\n",
    "            return x, layer_index\n",
    "\n",
    "    # The target lies beyond this container: hand the running output and the\n",
    "    # remaining count back so the caller continues with the next sibling.\n",
    "    # (Falling off the end here would return None and break the caller's\n",
    "    # tuple unpacking.)\n",
    "    return x, layer_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "According to the internet, a \"hook\" is *a place and usually an interface\n",
    "provided in packaged code that allows a programmer to insert customized\n",
    "programming*. Pytorch allows users to register hook functions in the network.\n",
    "A 'forward hook' is a function that will be automatically invoked whenever the\n",
    "network does a forward pass. Similarly, a 'backward hook' is invoked whenever\n",
    "the network does a backward pass (i.e., backprop). A forward hook must be\n",
    "formatted as follows:\n",
    "\n",
    "            hook_function_name(module, ten_in, ten_out)\n",
    "\n",
    "and a backward hook must be formatted as follows:\n",
    "\n",
    "            hook_function_name(module, grad_in, grad_out)\n",
    "\n",
    "A forward hook may replace the layer's output by returning a new tensor; a\n",
    "backward hook may likewise return a replacement for grad_in. Next, to\n",
    "register a forward hook to a layer object, do this:\n",
    "\n",
    "            layer.register_forward_hook(hook_function_name)\n",
    "\n",
    "To register a backward hook, do this (register_backward_hook is the older,\n",
    "deprecated spelling):\n",
    "\n",
    "            layer.register_full_backward_hook(hook_function_name)\n",
    "\n",
    "Since a hook function can only return a replacement tensor/gradient, it is\n",
    "not very useful on its own. Therefore, it is usually written as a class\n",
    "method, so that it can write to or read from the class attributes when it is\n",
    "invoked. The following block of code uses a hook function to find the\n",
    "maximum sizes of the receptive fields (RF) of all convolutional layers of a\n",
    "given network.\n",
    "\"\"\"\n",
    "\n",
    "class HookFunctionBase:\n",
    "    \"\"\"\n",
    "    A base class that registers a hook function to all specified layer types\n",
    "    (excluding all container types) in a given model. The child class must\n",
    "    implement hook_function(). The child class must also call\n",
    "    self.register_forward_hook_to_layers() by itself.\n",
    "    \"\"\"\n",
    "    def __init__(self, model, layer_types):\n",
    "        \"\"\"\n",
    "        Constructs a HookFunctionBase object.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        model : torch.nn.Module\n",
    "            The neural network. It is deep-copied; hooks go on the copy.\n",
    "        layer_types : type or tuple of types\n",
    "            The layer type(s) you would like to register the forward\n",
    "            hook to. For example, layer_types = (nn.Conv2d, nn.ReLU) means\n",
    "            that all the Conv2d and ReLU layers will be registered with the\n",
    "            forward hook.\n",
    "        \"\"\"\n",
    "        # DEEPCOPY IS A MUST: hooks are attached to self.model, so copying\n",
    "        # keeps them off the caller's model.\n",
    "        self.model = copy.deepcopy(model)\n",
    "        self.layer_types = layer_types\n",
    "\n",
    "    def hook_function(self, module, ten_in, ten_out):\n",
    "        raise NotImplementedError(\"Child class of HookFunctionBase must \"\n",
    "                                  \"implement hook_function(self, module, ten_in, ten_out)\")\n",
    "\n",
    "    def register_forward_hook_to_layers(self, layer):\n",
    "        \"\"\"Recursively attach self.hook_function to every leaf (non-container)\n",
    "        module under `layer` that is an instance of self.layer_types.\"\"\"\n",
    "        children = list(layer.children())\n",
    "        if children:  # A container: recurse into its sub-layers.\n",
    "            for sublayer in children:\n",
    "                self.register_forward_hook_to_layers(sublayer)\n",
    "        elif isinstance(layer, self.layer_types):\n",
    "            layer.register_forward_hook(self.hook_function)\n",
    "\n",
    "\n",
    "class RfSize(HookFunctionBase):\n",
    "    \"\"\"\n",
    "    Computes the receptive-field (RF) size of every Conv2d layer from the\n",
    "    kernel sizes and strides of the Conv2d / AvgPool2d / MaxPool2d layers\n",
    "    seen during one forward pass. Padding and dilation are not used.\n",
    "    \"\"\"\n",
    "    def __init__(self, model, image_size):\n",
    "        \"\"\"\n",
    "        Constructs a RfSize object that calculates the sizes of the receptive\n",
    "        fields (RFs), assuming all RFs are square.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        model : torch.nn.Module\n",
    "            The neural network.\n",
    "        image_size : int\n",
    "            The size of the input image. This should not really matter unless\n",
    "            the image size is smaller than the RF size (RF sizes are clipped\n",
    "            to image_size).\n",
    "        \"\"\"\n",
    "        # Hook every leaf layer so that layer_counter counts all of them.\n",
    "        super().__init__(model, layer_types=(torch.nn.Module,))\n",
    "        self.image_size = image_size\n",
    "        # Layer types whose kernel size and stride change the RF.\n",
    "        self.need_convsersion = (nn.Conv2d, nn.AvgPool2d, nn.MaxPool2d)\n",
    "        self._reset()\n",
    "\n",
    "        self.register_forward_hook_to_layers(self.model)\n",
    "\n",
    "    def _reset(self):\n",
    "        \"\"\"Clear the per-forward-pass bookkeeping filled in by hook_function().\"\"\"\n",
    "        self.layer_indices = []\n",
    "        self.rf_sizes = []\n",
    "        self.rf_size = None\n",
    "        self.layer_counter = 0\n",
    "        # Seeded with a kernel-1, stride-1 \"pixel layer\" so that the\n",
    "        # back-projection in _calculate_rf() ends in pixel units.\n",
    "        self.strides = [1]\n",
    "        self.kernel_sizes = [1]\n",
    "\n",
    "    def _clip(self, val, vmin, vmax):\n",
    "        \"\"\"Limits value to be vmin <= val <= vmax\"\"\"\n",
    "        if vmin > vmax:\n",
    "            raise Exception(\"vmin should be smaller than vmax.\")\n",
    "        return max(min(val, vmax), vmin)\n",
    "\n",
    "    def _calculate_rf(self, stride, kernel_size):\n",
    "        \"\"\"\n",
    "        Calculates the RF size (pix) of this layer and records its stride and\n",
    "        kernel size for the layers that follow. Tuple values are reduced to\n",
    "        their first element (square kernels assumed).\n",
    "        \"\"\"\n",
    "        if isinstance(stride, tuple):\n",
    "            stride = stride[0]\n",
    "        if isinstance(kernel_size, tuple):\n",
    "            kernel_size = kernel_size[0]\n",
    "\n",
    "        # Back-project one unit of this layer toward the pixels. tmp_rf_size\n",
    "        # starts as this layer's kernel size: the number of units of the\n",
    "        # previous layer that one unit here sees. For each earlier layer, with\n",
    "        # kernel ks and stride s, tmp_rf_size units of its output are computed\n",
    "        # from\n",
    "        #\n",
    "        #     ks + (tmp_rf_size - 1) * s\n",
    "        #\n",
    "        # units of its input. Repeating this down to the 1x1 pixel \"layer\"\n",
    "        # gives how many pixels ultimately project to one unit here.\n",
    "        #\n",
    "        # For example:\n",
    "        # conv1 in    * * * * * * * * * * * * * * * * *\n",
    "        # --------    - - - - -\n",
    "        # ks1 = 5         | - - - - -\n",
    "        # s1  = 3         |     | - - - - -\n",
    "        # rf1 = 5         |     |     | - - - - -\n",
    "        #                 |     |     |     | - - - - -\n",
    "        #                 |     |     |     |     |\n",
    "        #                 V     V     V     V     V\n",
    "        # conv1 out       *     *     *     *     *\n",
    "        # ---------       -     -     -\n",
    "        # ks2 = 3               -     -     -\n",
    "        # s2  = 1               |     -     -     -\n",
    "        # rf2 = 11              |     |     |\n",
    "        #                       V     V     V\n",
    "        # conv2 out             *     *     *\n",
    "        # ---------             -     -     -\n",
    "        # ks3 = 3                     |\n",
    "        # rf3 = 17                    V\n",
    "        # conv3 out                   *\n",
    "        tmp_rf_size = kernel_size\n",
    "        for s, ks in zip(self.strides[::-1], self.kernel_sizes[::-1]):\n",
    "            tmp_rf_size = ks + (tmp_rf_size - 1) * s\n",
    "\n",
    "        self.strides.append(stride)\n",
    "        self.kernel_sizes.append(kernel_size)\n",
    "        return tmp_rf_size\n",
    "\n",
    "    def hook_function(self, module, ten_in, ten_out):\n",
    "        \"\"\"Forward hook: update the running RF size and, for Conv2d layers,\n",
    "        record (layer index, RF size). Every hooked leaf layer advances\n",
    "        layer_counter.\"\"\"\n",
    "        if isinstance(module, self.need_convsersion):\n",
    "            self.rf_size = self._calculate_rf(module.stride, module.kernel_size)\n",
    "            self.rf_size = self._clip(self.rf_size, 0, self.image_size)\n",
    "\n",
    "        if isinstance(module, nn.Conv2d):\n",
    "            # Record layer index and RF size if module is a conv layer.\n",
    "            self.layer_indices.append(self.layer_counter)\n",
    "            self.rf_sizes.append(self.rf_size)\n",
    "\n",
    "        self.layer_counter += 1\n",
    "\n",
    "    def calculate(self):\n",
    "        \"\"\"\n",
    "        Run one dummy forward pass and collect the RF size of every conv layer.\n",
    "\n",
    "        The bookkeeping is reset first, so repeated calls return the same\n",
    "        result instead of accumulating stale strides, kernel sizes and indices.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        layer_indices : [int, ...]\n",
    "            The indices of all convolutional layers (counting every leaf\n",
    "            layer in the order its hook fires).\n",
    "        rf_sizes : [int, ...]\n",
    "            The RF sizes of all convolutional layers, clipped to image_size.\n",
    "        \"\"\"\n",
    "        self._reset()\n",
    "        dummy_input = torch.ones((1, 3, self.image_size, self.image_size))\n",
    "        with torch.no_grad():  # only the hooks matter; no autograd graph needed\n",
    "            self.model(dummy_input)\n",
    "        return self.layer_indices, self.rf_sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Here is how we define the 'center' point of a space.\"\"\"\n",
    "\n",
    "def calculate_center(output_size):\n",
    "    \"\"\"\n",
    "    center = (output_size - 1)//2, applied per dimension.\n",
    "\n",
    "    For an even size this picks the lower of the two middle indices.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    output_size : int or (int, int)\n",
    "        The size of the output maps in (height, width) format.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The index (int) or indices (int, int) of the spatial center.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError if a tuple/list/array output_size does not have length 2.\n",
    "    \"\"\"\n",
    "    if not isinstance(output_size, (tuple, list, np.ndarray)):\n",
    "        return (output_size - 1)//2\n",
    "    if len(output_size) != 2:\n",
    "        raise ValueError(\"output_size should have a length of 2.\")\n",
    "    height, width = output_size\n",
    "    return calculate_center(height), calculate_center(width)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "The RF of the early layers are small, so it will be a waste of computational\n",
    "resources to present the full image of size 227x227. The following function\n",
    "calculates the image sizes (of all conv layers) that are just big enough to\n",
    "fit the RFs of the center units, while centering the RF such that the paddings\n",
    "in all directions are the same.\n",
    "\"\"\"\n",
    "\n",
    "def xn_to_center_rf(model):\n",
    "    \"\"\"\n",
    "    Return the input image size xn = yn just big enough to center the RF of\n",
    "    the center units (of all Conv2d layer) in the pixel space. Need this\n",
    "    function because we don't want to use the full image size (227, 227).\n",
    "\n",
    "    For each conv layer, candidate sizes start just above 1.1x its RF size\n",
    "    and grow one pixel at a time. A candidate is accepted when the padding\n",
    "    (xn - rf_size) splits evenly and setting the top-left and bottom-right\n",
    "    padding corners to a huge value leaves the channel-0 response of the\n",
    "    center unit unchanged.\n",
    "\n",
    "    Side effect: model is switched to eval() mode.\n",
    "    NOTE(review): the search has no upper bound; if a layer's true RF is\n",
    "    larger than the clipped rf_size (RfSize clips at 227) it may not stop.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : torchvision.models\n",
    "        The neural network.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    xn_list : [int, ...]\n",
    "        A list of xn (which is also yn since we assume RF to be square).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    rf_size_object = RfSize(model, 227)\n",
    "    layer_indices, rf_sizes = rf_size_object.calculate()\n",
    "    xn_list = []\n",
    "\n",
    "    with torch.no_grad():  # Only responses are compared; no autograd needed.\n",
    "        for layer_index, rf_size in zip(layer_indices, rf_sizes):\n",
    "            xn = int(rf_size * 1.1)  # add a 10% padding.\n",
    "            while True:\n",
    "                xn += 1\n",
    "                # Skip sizes where the paddings on two sides aren't equal,\n",
    "                # before measuring anything, so a response from an earlier\n",
    "                # candidate can never be compared with the current one.\n",
    "                if (xn - rf_size) % 2 != 0:\n",
    "                    continue\n",
    "                padding = (xn - rf_size) // 2\n",
    "\n",
    "                dummy_input = torch.rand((1, 3, xn, xn))\n",
    "                y, _ = truncated_model(dummy_input, model, layer_index)\n",
    "                yc, xc = calculate_center(y.shape[-2:])\n",
    "                center_response_before = y[0, 0, yc, xc].item()\n",
    "\n",
    "                # Add perturbation to two opposite padding corners. Index from\n",
    "                # xn - padding (not -padding) so that padding == 0 selects\n",
    "                # nothing instead of the whole image.\n",
    "                far = xn - padding\n",
    "                dummy_input[:, :, :padding, :padding] = 10000\n",
    "                dummy_input[:, :, far:, far:] = 10000\n",
    "                y, _ = truncated_model(dummy_input, model, layer_index)\n",
    "                center_response_after = y[0, 0, yc, xc].item()\n",
    "\n",
    "                # If response before and after perturbation are identical,\n",
    "                # the unit RF is centered.\n",
    "                if center_response_before == center_response_after:\n",
    "                    break\n",
    "\n",
    "            xn_list.append(xn)\n",
    "    return xn_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Functions used for stimulus set creation.\"\"\"\n",
    "\n",
    "\n",
    "#######################################.#######################################\n",
    "#                                                                             #\n",
    "#                             STIM_DAPP_BAR_XYO_C                             #\n",
    "#                                                                             #\n",
    "#  Create dictionary entries for bar stimuli varying in these parameters:     #\n",
    "#    x location                                                               #\n",
    "#    y location                                                               #\n",
    "#    orientation                                                              #\n",
    "#    luminance contrast polarity                                              #\n",
    "#                                                                             #\n",
    "#  The input parameters specify the range of x-values to use, and these       #\n",
    "#  are replicated for the y-range as well.                                    #\n",
    "#  'orilist' gives the orientations to use.                                   #\n",
    "#                                                                             #\n",
    "#  The other bar parameters are held fixed:  length, width, anti-aliasing.    #\n",
    "#                                                                             #\n",
    "###############################################################################\n",
    "def stim_dapp_bar_xyo_c(splist,xn,xlist,orilist,blen,bwid,aa,fgvals,bgvals):\n",
    "    \"\"\"\n",
    "    Parameters\n",
    "    ----------\n",
    "    splist  - stimulus parameter list - APPEND TO THIS LIST\n",
    "    xn      - horizontal and vertical image size\n",
    "    xlist   - list of x-coordinates (pix); reused as the y-coordinates\n",
    "    orilist - list of orientation values (degr)\n",
    "    blen    - Length of bar (pix)\n",
    "    bwid    - Width of bar (pix)\n",
    "    aa      - Anti-aliasing space constant (pix)\n",
    "    fgvals  - foreground rgb values ([(float, float, float), ...])\n",
    "    bgvals  - background rgb values, paired element-wise with fgvals\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError if fgvals and bgvals differ in length (zip() would otherwise\n",
    "    silently drop the unpaired colors).\n",
    "    \"\"\"\n",
    "    # Materialize once: the color pairs are reused for every (x, y, ori).\n",
    "    fgvals = list(fgvals)\n",
    "    bgvals = list(bgvals)\n",
    "    if len(fgvals) != len(bgvals):\n",
    "        raise ValueError(\"fgvals and bgvals must have the same length \"\n",
    "                         f\"(got {len(fgvals)} and {len(bgvals)}).\")\n",
    "    color_pairs = list(zip(fgvals, bgvals))\n",
    "\n",
    "    yn = xn        # Assuming image is square\n",
    "    ylist = xlist  # Use same coordinates for y grid locations\n",
    "\n",
    "    for i in xlist:\n",
    "        for j in ylist:\n",
    "            for o in orilist:\n",
    "                for fgval, bgval in color_pairs:\n",
    "                    tp = {\"xn\":xn, \"yn\":yn, \"x0\":i, \"y0\":j, \"theta\":o, \"len\":blen,\n",
    "                        \"wid\":bwid, \"aa\":aa, \"fgval\":fgval, \"bgval\":bgval}\n",
    "                    splist.append(tp)\n",
    "\n",
    "\n",
    "#######################################.#######################################\n",
    "#                                                                             #\n",
    "#                             STIMSET_GRIDX_BARMAP                            #\n",
    "#                                                                             #\n",
    "#  Given a bar length and maximum RF size (both in pixels), return a list     #\n",
    "#  of the x-coordinates of the grid points relative to the center of the      #\n",
    "#  image field.                                                               #\n",
    "#                                                                             #\n",
    "#  The following hold by construction:                                        #\n",
    "#  (1) The center coordinate \"0.0\" will always be included                    #\n",
    "#  (2) There will be an odd number of coordinates                             #\n",
    "#  (3) The extreme coordinates will never be more than half of a bar length   #\n",
    "#      outside of the maximum RF ('max_rf')                                   #\n",
    "#                                                                             #\n",
    "###############################################################################\n",
    "def stimset_gridx_barmap(max_rf,blen):\n",
    "    \"\"\"\n",
    "    Parameters\n",
    "    ----------\n",
    "    max_rf - maximum RF size (pix)\n",
    "    blen   - bar length (pix)\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    xlist - numpy array of grid offsets (pix), symmetric about 0.0, spaced\n",
    "            blen/2 apart, with an odd number of entries.\n",
    "    \"\"\"\n",
    "    dx = blen / 2.0                           # Grid spacing is 1/2 of bar length\n",
    "    n_side = int(round((max_rf/dx) / 2.0))    # Grid points on each side of center\n",
    "    # Scale integer steps instead of np.arange(-xmax, xmax+1, dx): with a float\n",
    "    # step the \"+1\" stop adds a stray, asymmetric point whenever dx < 1 pixel,\n",
    "    # and the float arithmetic is not guaranteed to land exactly on 0.0.\n",
    "    xlist = np.arange(-n_side, n_side + 1) * dx\n",
    "\n",
    "    return xlist\n",
    "\n",
    "\n",
    "def stimset_dict(xn,max_rf,blen_rf_ratios,aspect_ratios,orientations,fgvals,bgvals):\n",
    "    \"\"\"\n",
    "    Build the bar-stimulus parameter list for one layer: every combination\n",
    "    of bar length, aspect ratio, grid position, orientation and color pair.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    xn     - stimulus image size (pix)\n",
    "    max_rf - maximum RF size (pix)\n",
    "    blen_rf_ratios - list of bar length to RF ratios ([float,...])\n",
    "    aspect_ratios  - list of bar width to bar length ratios ([float,...])\n",
    "    orientations   - list of bar orientations in degrees ([float,...])\n",
    "    fgvals - list of foreground rgb tuples ([(float, float, float), ...])\n",
    "    bgvals - list of background rgb tuples ([(float, float, float), ...])\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    splist - List of dictionary entries, one per stimulus image.\n",
    "    \"\"\"\n",
    "    aa = 0.5  # Antialias distance (pix): sets how much the bar edges blur\n",
    "\n",
    "    splist = []\n",
    "    for blen in (ratio * max_rf for ratio in blen_rf_ratios):\n",
    "        # The position grid depends only on bar length, not on bar width.\n",
    "        xlist = stimset_gridx_barmap(max_rf, blen)\n",
    "        for bwid in (blen * ar for ar in aspect_ratios):\n",
    "            stim_dapp_bar_xyo_c(splist, xn, xlist, orientations,\n",
    "                                blen, bwid, aa, fgvals, bgvals)\n",
    "    return splist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_box(box_indices, linewidth=1):\n",
    "    \"\"\"\n",
    "    Build a red, unfilled matplotlib Rectangle from inclusive pixel indices.\n",
    "\n",
    "    box_indices is (vx_min, hx_min, vx_max, hx_max), where v is the vertical\n",
    "    (row) index and h the horizontal (column) index. Both ends are included,\n",
    "    so the box is (hx_max - hx_min + 1) wide and (vx_max - vx_min + 1) tall.\n",
    "    Example usage:\n",
    "\n",
    "        plt.imshow(img)\n",
    "        ax = plt.gca()\n",
    "        rect = make_box((0, 0, 100, 50))\n",
    "        ax.add_patch(rect)\n",
    "\n",
    "    This plots a red box 101 pixels tall and 51 pixels wide at the\n",
    "    top-left corner of img.\n",
    "\n",
    "    NOTE(review): imshow centers pixel (r, c) on the integer point (c, r),\n",
    "    so this outline may sit half a pixel right of / below the covered\n",
    "    pixels; confirm whether (hx_min - 0.5, vx_min - 0.5) was intended.\n",
    "    \"\"\"\n",
    "    row_min, col_min, row_max, col_max = box_indices\n",
    "    anchor = (col_min, row_min)  # Rectangle wants (x, y) = (column, row).\n",
    "    return patches.Rectangle(anchor,\n",
    "                             col_max - col_min + 1,  # width\n",
    "                             row_max - row_min + 1,  # height\n",
    "                             linewidth=linewidth, edgecolor='r', facecolor='none')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Animation\n",
    "\n",
    "**Warning:** This animation contains rapid flashing. Please look away if you feel uncomfortable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[W NNPACK.cpp:79] Could not initialize NNPACK! Reason: Unsupported hardware.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "25aaf94629b1413f8f477c922ae9b50c",
       "version_major": 2,
       "version_minor": 0
      },
      "image/png": "",
      "text/html": [
       "\n",
       "            <div style=\"display: inline-block;\">\n",
       "                <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n",
       "                    Figure\n",
       "                </div>\n",
       "                <img src='' width=800.0/>\n",
       "            </div>\n",
       "        "
      ],
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Input size (xn) per conv layer, just big enough for its RF (see\n",
    "# xn_to_center_rf), then the layer indices and max RF sizes of all conv layers.\n",
    "xn_list = xn_to_center_rf(model)\n",
    "rf_size_object = RfSize(model, 227)  # 227: presumably the reference input size\n",
    "layer_indices, max_rfs = rf_size_object.calculate()\n",
    "\n",
    "# Look up the chosen layer (conv_i is 0-based, so conv_i = 0 -> conv1).\n",
    "layer_name = f\"conv{conv_i + 1}\"\n",
    "xn, layer_idx, max_rf = xn_list[conv_i], layer_indices[conv_i], max_rfs[conv_i]\n",
    "\n",
    "# Two panels: ax1 shows the current bar, ax2 the response-weighted bar sum.\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2)\n",
    "fig.set_size_inches(8, 4)\n",
    "fig.suptitle(f\"{layer_name} unit no.{unit_i}\", fontsize=20)\n",
    "ax1.clear()\n",
    "ax2.clear()\n",
    "# Blank placeholder images; animate_func replaces their data frame by frame.\n",
    "im1 = ax1.imshow(np.zeros((xn, xn, 3)), vmin=-1, vmax=1)\n",
    "im2 = ax2.imshow(np.zeros((xn, xn, 3)), vmin=0, vmax=1)\n",
    "\n",
    "def bar_generator(splist, model, layer_idx, unit_i):\n",
    "    \"\"\"\n",
    "    Present each bar in splist to the truncated model and yield the results.\n",
    "    \n",
    "    Note for user - create an instance of this generator and pass it as the\n",
    "                    'frames' keyword argument of animation.FuncAnimation().\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    splist    - bar stimulus parameter list (list of dicts, e.g. from\n",
    "                stimset_dict()). All entries are assumed to share the 'xn'\n",
    "                and 'yn' of the first entry.\n",
    "    model     - neural network.\n",
    "    layer_idx - index of the layer of interest (passed to truncated_model()).\n",
    "    unit_i    - index of the unit of interest (output channel index).\n",
    "    \n",
    "    Yields\n",
    "    ------\n",
    "    center_response - response of the unit at the spatial center of the\n",
    "                      layer output to bar_i (float).\n",
    "    new_bar  - bar image in (3, yn, xn) format (numpy.ndarray).\n",
    "    bar_sum  - running sum of bars weighted by max(center_response, 0), in\n",
    "               (3, yn, xn) format (numpy.ndarray). The same array is updated\n",
    "               in place and re-yielded every step; copy it to keep a snapshot.\n",
    "    bar_i    - the 0-based index of the bar stimulus (int).\n",
    "    num_bars - the total number of bars (int).\n",
    "    \n",
    "    Yields nothing if splist is empty.\n",
    "    \"\"\"\n",
    "    num_bars = len(splist)\n",
    "    if num_bars == 0:\n",
    "        return  # Nothing to present; also avoids an IndexError on splist[0].\n",
    "\n",
    "    xn = splist[0]['xn']\n",
    "    yn = splist[0]['yn']\n",
    "    bar_sum = np.zeros((3, yn, xn))\n",
    "\n",
    "    for bar_i, params in enumerate(splist):\n",
    "        new_bar = stimfr_bar_color(params['xn'], params['yn'],\n",
    "                                   params['x0'], params['y0'],\n",
    "                                   params['theta'], params['len'], params['wid'],\n",
    "                                   params['aa'], *params['fgval'], *params['bgval'])\n",
    "        new_bar_tensor = torch.tensor(np.expand_dims(new_bar, axis=0)).type('torch.FloatTensor')\n",
    "        # Inference only: don't build an autograd graph for every bar. The\n",
    "        # no_grad block must close before the yield below; otherwise grad mode\n",
    "        # would stay disabled in the caller's code between iterations.\n",
    "        with torch.no_grad():\n",
    "            # Present the bar to the truncated model.\n",
    "            y, _ = truncated_model(new_bar_tensor, model, layer_idx)\n",
    "        yc, xc = calculate_center(y.shape[-2:])\n",
    "        center_response = float(y[0, unit_i, yc, xc])\n",
    "        # Only excitatory (positive) responses contribute to the map.\n",
    "        bar_sum += max(center_response, 0.0) * new_bar\n",
    "        yield center_response, new_bar, bar_sum, bar_i, num_bars\n",
    "\n",
    "def animate_func(frame):\n",
    "    \"\"\"\n",
    "    Animation function used to update the bar mapping images.\n",
    "    \n",
    "    Note for user - pass the function name as the second argument of\n",
    "                    animation.FuncAnimation().\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    frame - one 5-tuple yielded by bar_generator():\n",
    "        frame[0] - center response of the unit to bar_i (float).\n",
    "        frame[1] - new bar image in (3, yn, xn) format (numpy.ndarray).\n",
    "        frame[2] - weighted sum of bars in (3, yn, xn) format (numpy.ndarray).\n",
    "        frame[3] - the 0-based index of the bar stimulus (int).\n",
    "        frame[4] - the total number of bars (int).\n",
    "    \"\"\"\n",
    "    response, new_bar, bar_sum, bar_i, num_bars = frame\n",
    "\n",
    "    # Left panel: shift the bar from [-1, 1] to [0, 1] for RGB display\n",
    "    # (assumes stimfr_bar_color outputs values in [-1, 1] -- TODO confirm).\n",
    "    im1.set_data((np.transpose(new_bar, (1, 2, 0)) + 1) / 2)\n",
    "    ax1.set_title(f\"response = {response:.2f}\")\n",
    "\n",
    "    # Right panel: min-max normalize the weighted bar sum to [0, 1].\n",
    "    vmin = bar_sum.min()\n",
    "    vmax = bar_sum.max()\n",
    "    vrange = vmax - vmin\n",
    "    if math.isclose(vrange, 0):  # prevent zero-division (e.g. all-zero sum)\n",
    "        vrange = 1\n",
    "    im2.set_data((np.transpose(bar_sum, (1, 2, 0)) - vmin) / vrange)\n",
    "    # 1-based counter so the last frame reads \"N/N\" rather than \"N-1/N\".\n",
    "    ax2.set_title(f\"bar no. {bar_i + 1}/{num_bars}\")\n",
    "\n",
    "# Build the stimulus set and a generator that feeds it through the model.\n",
    "splist = stimset_dict(xn, max_rf, blen_rf_ratios, aspect_ratios,\n",
    "                      orientations, fgvals, bgvals)\n",
    "bar_gen = bar_generator(splist, model, layer_idx, unit_i)\n",
    "\n",
    "# NOTE: the animation only keeps running while the variable `ani` is alive,\n",
    "# and it needs an interactive backend such as ipympl (`%matplotlib widget`,\n",
    "# an alias of `%matplotlib ipympl`; install with `%pip install ipympl`).\n",
    "# Errors raised inside the animation callback tend to be silenced, which\n",
    "# makes it hard to debug. For details, see\n",
    "# https://matplotlib.org/stable/api/_as_gen/matplotlib.animation.FuncAnimation.html#matplotlib.animation.FuncAnimation\n",
    "ani = animation.FuncAnimation(\n",
    "    fig,\n",
    "    animate_func,\n",
    "    frames=bar_gen,\n",
    "    interval=animation_interval_ms,\n",
    "    save_count=0,\n",
    "    cache_frame_data=False,\n",
    "    repeat=False,\n",
    ")\n",
    "\n",
    "# Outline the RF on both panels. Assuming the RF is centered in the xn x xn\n",
    "# image, the margin on each side is (xn - max_rf) // 2; the box starts one\n",
    "# index before that margin so the outline surrounds the RF.\n",
    "padding = (xn - max_rf) // 2\n",
    "rf_box = (padding - 1, padding - 1, padding + max_rf - 1, padding + max_rf - 1)\n",
    "rect1 = make_box(rf_box)\n",
    "ax1.add_patch(rect1)\n",
    "rect2 = make_box(rf_box)  # a patch cannot be shared between Axes\n",
    "ax2.add_patch(rect2)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "944af25d1661a4e0ae0843253367de58aba60b991b9300416bbd1a841a7ac55a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}