{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.1"
    },
    "colab": {
      "name": "Grad_CAM_no_solution.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hslsm66juEKM"
      },
      "source": [
        "## Visualization of CNN: Grad-CAM\n",
        "* **Objective**: Convolutional Neural Networks are widely used on computer vision. It is powerful for processing grid-like data. However we hardly know how and why it works, due to the lack of decomposability into individually intuitive components. In this assignment, we will introduce the Grad-CAM which visualizes the heatmap of input images by highlighting the important region for visual question answering(VQA) task.\n",
        "\n",
        "* **To be submitted**: this notebook in two weeks, **cleaned** (i.e. without results, for file size reasons: `menu > kernel > restart and clean`), in a state ready to be executed (if one just presses 'Enter' till the end, one should obtain all the results for all images) with a few comments at the end. No additional report, just the notebook!\n",
        "\n",
        "* NB: if `PIL` is not installed, try `conda install pillow`.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PapIrgsiuEKR"
      },
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "\n",
        "import torchvision.transforms as transforms\n",
        "from PIL import Image\n",
        "\n",
        "import warnings\n",
        "import matplotlib.pyplot as plt\n",
        "warnings.filterwarnings('ignore')\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VD4G5LOEHwcG"
      },
      "source": [
        "import cv2"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mlT1eN2RuEKS"
      },
      "source": [
        "### Visual Question Answering problem\n",
        "Given an image and a question in natural language, the model choose the most likely answer from 3 000 classes according to the content of image. The VQA task is indeed a multi-classificaition problem.\n",
        "<img src=\"vqa_model.PNG\">\n",
        "\n",
        "We provide you a pretrained model `vqa_resnet` for VQA tasks."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GARRQFOjuEKU"
      },
      "source": [
        "# load model\n",
        "from load_model import load_model\n",
        "vqa_resnet = load_model()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zdpiRJkJuEKW"
      },
      "source": [
        "print(vqa_resnet) # for more information "
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IkPs88LauEKW"
      },
      "source": [
        "checkpoint = '2017-08-04_00.55.19.pth'\n",
        "saved_state = torch.load(checkpoint, map_location=device)\n",
        "# reading vocabulary from saved model\n",
        "vocab = saved_state['vocab']\n",
        "\n",
        "# reading word tokens from saved model\n",
        "token_to_index = vocab['question']\n",
        "\n",
        "# reading answers from saved model\n",
        "answer_to_index = vocab['answer']\n",
        "\n",
        "num_tokens = len(token_to_index) + 1\n",
        "\n",
        "# reading answer classes from the vocabulary\n",
        "answer_words = ['unk'] * len(answer_to_index)\n",
        "for w, idx in answer_to_index.items():\n",
        "    answer_words[idx]=w"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4JI3k_E2uEKY"
      },
      "source": [
        "### Inputs\n",
        "In order to use the pretrained model, the input image should be normalized using `mean = [0.485, 0.456, 0.406]`, and `std = [0.229, 0.224, 0.225]`, and be resized as `(448, 448)`. You can call the function `image_to_features` to achieve image preprocessing. For input question, the function `encode_question` is provided to encode the question into a vector of indices. You can also use `preprocess` function for both image and question preprocessing."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SSoVgneNuEKZ"
      },
      "source": [
        "def get_transform(target_size, central_fraction=1.0):\n",
        "    return transforms.Compose([\n",
        "        transforms.Scale(int(target_size / central_fraction)),\n",
        "        transforms.CenterCrop(target_size),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
        "                             std=[0.229, 0.224, 0.225]),\n",
        "    ])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "g4hbpHuouEKa"
      },
      "source": [
        "def encode_question(question):\n",
        "    \"\"\" Turn a question into a vector of indices and a question length \"\"\"\n",
        "    question_arr = question.lower().split()\n",
        "    vec = torch.zeros(len(question_arr), device=device).long()\n",
        "    for i, token in enumerate(question_arr):\n",
        "        index = token_to_index.get(token, 0)\n",
        "        vec[i] = index\n",
        "    return vec, torch.tensor(len(question_arr), device=device)"
      ],
      "execution_count": null,
      "outputs": []
    },
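    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Optional sanity check: inspect what encode_question returns for a sample question.\n",
        "# Words that are not in token_to_index are mapped to index 0.\n",
        "sample_vec, sample_len = encode_question('What animal')\n",
        "print(sample_vec, sample_len)"
      ],
      "execution_count": null,
      "outputs": []
    },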
    {
      "cell_type": "code",
      "metadata": {
        "id": "dAZI1uTJuEKa"
      },
      "source": [
        "# preprocess requires the dir_path of an image and the associated question. \n",
        "#It returns the spectific input form which can be used directly by vqa model. \n",
        "def preprocess(dir_path, question):\n",
        "    q, q_len = encode_question(question)\n",
        "    img = Image.open(dir_path).convert('RGB')\n",
        "    image_size = 448  # scale image to given size and center\n",
        "    central_fraction = 1.0\n",
        "    transform = get_transform(image_size, central_fraction=central_fraction)\n",
        "    img_transformed = transform(img)\n",
        "    img_features = img_transformed.unsqueeze(0).to(device)\n",
        "    \n",
        "    inputs = (img_features, q.unsqueeze(0), q_len.unsqueeze(0))\n",
        "    return inputs"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BbW1kTlZuEKb"
      },
      "source": [
        "We provide you two pictures and some question-answers."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aoNOgPhXuEKc"
      },
      "source": [
        "Question1 = 'What animal'\n",
        "Answer1 = ['dog','cat' ]\n",
        "indices1 = [answer_to_index[ans] for ans in Answer1]# The indices of category \n",
        "img1 = Image.open('dog_cat.png')\n",
        "img1"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "77Q78qViuEKc"
      },
      "source": [
        "dir_path = 'dog_cat.png' \n",
        "inputs = preprocess(dir_path, Question1)\n",
        "ans = vqa_resnet(*inputs) # use model to predict the answer\n",
        "answer_idx = np.argmax(F.softmax(ans, dim=1).data.numpy())\n",
        "print(answer_words[answer_idx])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UkZsOxH0uEKc"
      },
      "source": [
        "Question2 = 'What color'\n",
        "Answer2 = ['green','yellow' ]\n",
        "indices2 = [answer_to_index[ans] for ans in Answer2]\n",
        "img2 = Image.open('hydrant.png')\n",
        "img2"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Byih1NwIuEKd"
      },
      "source": [
        "dir_path = 'hydrant.png' \n",
        "inputs = preprocess(dir_path, Question2)\n",
        "ans = vqa_resnet(*inputs) # use model to predict the answer\n",
        "answer_idx = np.argmax(F.softmax(ans, dim=1).data.numpy())\n",
        "print(answer_words[answer_idx])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7N6shd0ruEKe"
      },
      "source": [
        "### Grad-CAM \n",
        "* **Overview:** Given an image with a question, and a category (‘dog’) as input, we forward propagate the image through the model to obtain the `raw class scores` before softmax. The gradients are set to zero for all classes except the desired class (dog), which is set to 1. This signal is then backpropagated to the `rectified convolutional feature map` of interest, where we can compute the coarse Grad-CAM localization (blue heatmap).\n",
        "\n",
        "\n",
        "* **To Do**: Define your own function Grad_CAM to achieve the visualization of the two images. For each image, consider the answers we provided as the desired classes. Compare the heatmaps of different answers, and conclude. \n",
        "\n",
        "\n",
        "* **Hints**: \n",
        " + We need to record the output and grad_output of the feature maps to achieve Grad-CAM. In pytorch, the function `Hook` is defined for this purpose. Read the tutorial of [hook](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks) carefully. \n",
        " + The pretrained model `vqa_resnet` doesn't have the activation function after its last layer, the output is indeed the `raw class scores`, you can use it directly. Run \"print(vqa_resnet)\" to get more information on VGG model.\n",
        " + The last CNN layer of the model is: `vqa_resnet.resnet_layer4.r_model.layer4[2].conv3` \n",
        " + The size of feature maps is 14x14, so as your heatmap. You need to project the heatmap to the original image(224x224) to have a better observation. The function `cv2.resize()` may help.  \n",
        " + Here is the link of the paper [Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization](https://arxiv.org/pdf/1610.02391.pdf)"
      ]
    },
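    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the computation described above can be written compactly with the notation of the paper linked above. For a class $c$ and the $k$-th feature map $A^k$ of the chosen convolutional layer, the channel weights are the spatially averaged gradients of the raw class score $y^c$:\n",
        "\n",
        "$$\\alpha_k^c = \\frac{1}{Z} \\sum_i \\sum_j \\frac{\\partial y^c}{\\partial A^k_{ij}}$$\n",
        "\n",
        "and the coarse localization map is the ReLU of the weighted combination of feature maps:\n",
        "\n",
        "$$L^c_{\\text{Grad-CAM}} = \\mathrm{ReLU}\\Big(\\sum_k \\alpha_k^c A^k\\Big)$$\n",
        "\n",
        "where $Z$ is the number of spatial locations (here $14 \\times 14$)."
      ]
    },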
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rC-bzCdIuEKe"
      },
      "source": [
        "<img src=\"grad_cam.png\">"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nRKEVO37uEKe"
      },
      "source": [
        "# Preparing the hook in the vqa_resnet network\r\n",
        "def forward_hook(self, input, output):\r\n",
        "    global layer_activations\r\n",
        "    layer_activations = output\r\n",
        "\r\n",
        "def backward_hook(self, input, output):\r\n",
        "    global layer_gradients\r\n",
        "    layer_gradients = output[0]\r\n",
        "  \r\n",
        "handle_forward = vqa_resnet.resnet_layer4.r_model.layer4[2].conv3.register_forward_hook(forward_hook)\r\n",
        "handle_backward = vqa_resnet.resnet_layer4.r_model.layer4[2].conv3.register_backward_hook(backward_hook)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pr5lNX3z4T85"
      },
      "source": [
        "# handle_forward.remove()\r\n",
        "# handle_backward.remove()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6mit1YfULcdI"
      },
      "source": [
        "def grad_cam(question, answer_idx=None):\r\n",
        "    ''' Takes the question as input and saves the associated heatmap in the main directory '''\r\n",
        "\r\n",
        "    # prepare the inputs\r\n",
        "    if question == 2:\r\n",
        "        dir_path = 'hydrant.png' \r\n",
        "        inputs = preprocess(dir_path, Question2)\r\n",
        "    elif question == 1:\r\n",
        "        dir_path = 'dog_cat.png' \r\n",
        "        inputs = preprocess(dir_path, Question1)\r\n",
        "    else:\r\n",
        "        return 'Not a valid question'\r\n",
        "\r\n",
        "    # propagate through the model to obtain the raw score outputs\r\n",
        "    print('Propagating to get the predictions ...')\r\n",
        "    ans = vqa_resnet(*inputs)\r\n",
        "    if answer_idx is None:\r\n",
        "        answer_idx = np.argmax(F.softmax(ans, dim=1).data.numpy())\r\n",
        "    print('Predicted answer: ' + answer_words[answer_idx])\r\n",
        "\r\n",
        "    print('Preparing heatmap ...')\r\n",
        "    # backpropagate from the activated answer\r\n",
        "\r\n",
        "    vqa_resnet.zero_grad()\r\n",
        "    ans[:, answer_idx].backward()\r\n",
        "\r\n",
        "    # mean over the channels of the gradients\r\n",
        "    avg_gradients = torch.mean(layer_gradients, dim=[0, 2, 3])\r\n",
        "    # weight the channels according to the gradient\r\n",
        "    for i in range(layer_activations.shape[1]):\r\n",
        "        layer_activations[:, i, :, :] *= avg_gradients[i]\r\n",
        "    # average the channels of the activations\r\n",
        "    heatmap = torch.mean(layer_activations, dim=1).squeeze().detach().numpy()\r\n",
        "    # rescale the output so that it represents an image\r\n",
        "\r\n",
        "    print('Heatmap Preview:')\r\n",
        "    plt.matshow(heatmap.squeeze())\r\n",
        "    plt.show()\r\n",
        "\r\n",
        "    # heatmap = heatmap - np.amin(heatmap)\r\n",
        "    heatmap = np.maximum(heatmap, 0)\r\n",
        "    heatmap /= np.amax(heatmap)\r\n",
        "\r\n",
        "    img = cv2.imread(dir_path)\r\n",
        "    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))\r\n",
        "    heatmap = np.uint8(255 * heatmap)\r\n",
        "    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)\r\n",
        "    projected_heatmap = heatmap * 0.4 + img\r\n",
        "    save_path = dir_path.split('.')[0] + '_' +answer_words[answer_idx]+'_heatmap.jpg'\r\n",
        "    if cv2.imwrite(save_path, projected_heatmap):\r\n",
        "        print('Heatmap succesfully saved at '+ save_path)\r\n",
        "    else:\r\n",
        "        print('Error in saving the heatmap')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MFcOB2P4Hvev"
      },
      "source": [
        "question = 1\r\n",
        "print('-> Question '+str(question))\r\n",
        "grad_cam(question)\r\n",
        "print('')\r\n",
        "\r\n",
        "question = 1\r\n",
        "print('-> Question '+str(question))\r\n",
        "grad_cam(question, answer_idx = answer_to_index['cat'])\r\n",
        "print('')\r\n",
        "\r\n",
        "question = 2\r\n",
        "print('-> Question '+str(question))\r\n",
        "grad_cam(question)\r\n",
        "print('')\r\n",
        "\r\n",
        "question = 2\r\n",
        "print('-> Question '+str(question))\r\n",
        "grad_cam(question, answer_idx = answer_to_index['yellow'])\r\n",
        "print('')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VHhtTXOE7c9s"
      },
      "source": [
        "hm_dog = Image.open('dog_cat_dog_heatmap.jpg')\r\n",
        "hm_cat = Image.open('dog_cat_cat_heatmap.jpg')\r\n",
        "\r\n",
        "plt.figure(figsize=(10,7))\r\n",
        "\r\n",
        "plt.subplot(1, 2, 1)\r\n",
        "plt.imshow(hm_dog)\r\n",
        "plt.title('Heatmap for Dog answer')\r\n",
        "\r\n",
        "plt.subplot(1, 2, 2)\r\n",
        "plt.imshow(hm_cat)\r\n",
        "plt.title('Heatmap for Cat answer')\r\n",
        "\r\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-QiattRZBWKe"
      },
      "source": [
        "hm_green = Image.open('hydrant_green_heatmap.jpg')\r\n",
        "hm_yellow = Image.open('hydrant_yellow_heatmap.jpg')\r\n",
        "\r\n",
        "plt.figure(figsize=(10,7))\r\n",
        "\r\n",
        "plt.subplot(1, 2, 1)\r\n",
        "plt.imshow(hm_green)\r\n",
        "plt.title('Heatmap for Green answer')\r\n",
        "\r\n",
        "plt.subplot(1, 2, 2)\r\n",
        "plt.imshow(hm_yellow)\r\n",
        "plt.title('Heatmap for Yellow answer')\r\n",
        "\r\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B9Oo1BYfQ77e"
      },
      "source": [
        "As we can see, the heatmap has a higher value when it is near to the object subject to the answer (dog's face, cat's face, etc.). It can be a good way to see if the network focuses on the good place of the image, and thus, is a great tool for explainability. "
      ]
    }
  ]
}