  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.1"
    "colab": {
      "name": "Grad_CAM_no_solution.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
  "cells": [
      "cell_type": "markdown",
      "metadata": {
        "id": "Hslsm66juEKM"
      "source": [
        "## Visualization of CNN: Grad-CAM\n",
        "* **Objective**: Convolutional Neural Networks are widely used in computer vision and are powerful for processing grid-like data. However, we hardly know how and why they work, because they cannot be decomposed into individually intuitive components. In this assignment, we introduce Grad-CAM, which produces a heatmap over the input image highlighting the regions that matter most for the visual question answering (VQA) task.\n",
        "* **To be submitted**: this notebook in two weeks, **cleaned** (i.e. without results, for file size reasons: `menu > kernel > restart and clean`), in a state ready to be executed (if one just presses 'Enter' till the end, one should obtain all the results for all images) with a few comments at the end. No additional report, just the notebook!\n",
        "* NB: if `PIL` is not installed, try `conda install pillow`.\n"
      "cell_type": "code",
      "metadata": {
        "id": "PapIrgsiuEKR"
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "import torchvision.transforms as transforms\n",
        "from PIL import Image\n",
        "import warnings\n",
        "import matplotlib.pyplot as plt\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      "execution_count": null,
      "outputs": []
      "cell_type": "code",
      "metadata": {
        "id": "VD4G5LOEHwcG"
      "source": [
        "import cv2"
      "execution_count": null,
      "outputs": []
      "cell_type": "markdown",
      "metadata": {
        "id": "mlT1eN2RuEKS"
      "source": [
        "### Visual Question Answering problem\n",
        "Given an image and a question in natural language, the model chooses the most likely answer among 3 000 classes according to the content of the image. The VQA task is therefore a multi-class classification problem.\n",
        "<img src=\"vqa_model.PNG\">\n",
        "We provide you with a pretrained model `vqa_resnet` for the VQA task."
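Since the model outputs one raw score per answer class, choosing an answer is an argmax over those scores; the softmax only turns them into probabilities and does not change the ranking. A minimal NumPy sketch, assuming a made-up toy vocabulary of 5 answers and hypothetical scores (not the real 3 000 classes):

```python
import numpy as np

def softmax(scores):
    """Numerically stable softmax over the last axis."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Hypothetical toy vocabulary and raw class scores for a batch of one question
answer_words = ['unk', 'dog', 'cat', 'green', 'yellow']
raw_scores = np.array([[0.1, 2.5, 1.7, -0.3, 0.0]])

probs = softmax(raw_scores)
answer_idx = int(np.argmax(probs, axis=1)[0])
# softmax is monotonic, so the argmax of the raw scores is the same class
assert answer_idx == int(np.argmax(raw_scores, axis=1)[0])
print(answer_words[answer_idx], round(float(probs[0, answer_idx]), 3))
```

Because of this monotonicity, `ans.argmax(dim=1)` on the raw scores is enough to read off the predicted answer.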
      "cell_type": "code",
      "metadata": {
        "id": "GARRQFOjuEKU"
      "source": [
        "# load model\n",
        "from load_model import load_model\n",
        "vqa_resnet = load_model()"
      "execution_count": null,
      "outputs": []
      "cell_type": "code",
      "metadata": {
        "id": "zdpiRJkJuEKW"
      "source": [
        "print(vqa_resnet) # for more information "
      "execution_count": null,
      "outputs": []
      "cell_type": "code",
      "metadata": {
        "id": "IkPs88LauEKW"
      "source": [
        "checkpoint = '2017-08-04_00.55.19.pth'\n",
        "saved_state = torch.load(checkpoint, map_location=device)\n",
        "# reading vocabulary from saved model\n",
        "vocab = saved_state['vocab']\n",
        "# reading word tokens from saved model\n",
        "token_to_index = vocab['question']\n",
        "# reading answers from saved model\n",
        "answer_to_index = vocab['answer']\n",
        "num_tokens = len(token_to_index) + 1\n",
        "# reading answer classes from the vocabulary\n",
        "answer_words = ['unk'] * len(answer_to_index)\n",
        "for w, idx in answer_to_index.items():\n",
        "    answer_words[idx]=w"
      "execution_count": null,
      "outputs": []
      "cell_type": "markdown",
      "metadata": {
        "id": "4JI3k_E2uEKY"
      "source": [
        "### Inputs\n",
        "In order to use the pretrained model, the input image should be normalized using `mean = [0.485, 0.456, 0.406]` and `std = [0.229, 0.224, 0.225]`, and resized to `(448, 448)`. The function `get_transform` builds this image preprocessing pipeline. For the input question, the function `encode_question` is provided to encode the question into a vector of indices. You can also use the `preprocess` function for both image and question preprocessing."
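To make the normalization concrete, here is a NumPy sketch of what `ToTensor` followed by `Normalize` compute (resizing and cropping are left out), together with the inverse mapping, which is handy to display a preprocessed image again. It assumes a hypothetical random image already at 448x448:

```python
import numpy as np

# ImageNet statistics quoted above, one value per RGB channel
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def normalize(img_uint8):
    """(H, W, 3) uint8 image -> (3, H, W) float array, like ToTensor + Normalize."""
    x = img_uint8.astype(np.float64).transpose(2, 0, 1) / 255.0
    return (x - MEAN) / STD

def denormalize(x):
    """Inverse of `normalize`: (3, H, W) float array -> (H, W, 3) uint8 image."""
    img = (x * STD + MEAN).transpose(1, 2, 0)
    return np.clip(np.round(img * 255.0), 0, 255).astype(np.uint8)

# Hypothetical random 448x448 RGB image standing in for a real picture
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(448, 448, 3), dtype=np.uint8)
x = normalize(img)
print(x.shape)
assert np.array_equal(denormalize(x), img)  # the round trip is lossless
```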
      "cell_type": "code",
      "metadata": {
        "id": "SSoVgneNuEKZ"
      "source": [
        "def get_transform(target_size, central_fraction=1.0):\n",
        "    return transforms.Compose([\n",
        "        transforms.Resize(int(target_size / central_fraction)),  # resize the shorter side (Scale is deprecated)\n",
        "        transforms.CenterCrop(target_size),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
        "                             std=[0.229, 0.224, 0.225]),\n",
        "    ])"
      "execution_count": null,
      "outputs": []
      "cell_type": "code",
      "metadata": {
        "id": "g4hbpHuouEKa"
      "source": [
        "def encode_question(question):\n",
        "    \"\"\" Turn a question into a vector of indices and a question length \"\"\"\n",
        "    question_arr = question.lower().split()\n",
        "    vec = torch.zeros(len(question_arr), device=device).long()\n",
        "    for i, token in enumerate(question_arr):\n",
        "        index = token_to_index.get(token, 0)\n",
        "        vec[i] = index\n",
        "    return vec, torch.tensor(len(question_arr), device=device)"
      "execution_count": null,
      "outputs": []
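The behaviour of `encode_question` can be checked without the checkpoint. The sketch below mirrors its logic in pure Python on a hypothetical toy vocabulary (the real one is `vocab['question']`); the helper name `encode_question_list` is made up for illustration:

```python
# Hypothetical toy vocabulary; the real one comes from the saved checkpoint
token_to_index = {'what': 1, 'animal': 2, 'color': 3}

def encode_question_list(question, vocab):
    """Pure-Python mirror of `encode_question`: lower-case, split on whitespace,
    map each token to its index, and use 0 for out-of-vocabulary tokens."""
    tokens = question.lower().split()
    return [vocab.get(token, 0) for token in tokens], len(tokens)

indices, length = encode_question_list('What animal is this', token_to_index)
print(indices, length)
```

Note that `split()` does not strip punctuation, so a token such as `animal?` would also be mapped to index 0; the questions provided in this notebook contain no question mark.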
      "cell_type": "code",
      "metadata": {
        "id": "dAZI1uTJuEKa"
      "source": [
        "# preprocess takes the path of an image and the associated question.\n",
        "# It returns the inputs in the exact form expected by the VQA model.\n",
        "def preprocess(dir_path, question):\n",
        "    q, q_len = encode_question(question)\n",
        "    img = Image.open(dir_path).convert('RGB')\n",
        "    image_size = 448  # scale image to given size and center\n",
        "    central_fraction = 1.0\n",
        "    transform = get_transform(image_size, central_fraction=central_fraction)\n",
        "    img_transformed = transform(img)\n",
        "    img_features = img_transformed.unsqueeze(0).to(device)\n",
        "    \n",
        "    inputs = (img_features, q.unsqueeze(0), q_len.unsqueeze(0))\n",
        "    return inputs"
      "execution_count": null,
      "outputs": []
      "cell_type": "markdown",
      "metadata": {
        "id": "BbW1kTlZuEKb"
      "source": [
        "We provide you with two pictures and some question-answer pairs."
      "cell_type": "code",
      "metadata": {
        "id": "aoNOgPhXuEKc"
      "source": [
        "Question1 = 'What animal'\n",
        "Answer1 = ['dog','cat' ]\n",
        "indices1 = [answer_to_index[ans] for ans in Answer1]  # indices of the answer classes\n",
        "img1 = Image.open('dog_cat.png')"
      "execution_count": null,
      "outputs": []
      "cell_type": "code",
      "metadata": {
        "id": "77Q78qViuEKc"
      "source": [
        "dir_path = 'dog_cat.png' \n",
        "inputs = preprocess(dir_path, Question1)\n",
        "ans = vqa_resnet(*inputs)  # use the model to get the raw class scores\n",
        "answer_idx = ans.argmax(dim=1).item()  # softmax does not change the argmax\n",
        "print('Predicted answer: ' + answer_words[answer_idx])"
      "execution_count": null,
      "outputs": []
      "cell_type": "code",
      "metadata": {
        "id": "UkZsOxH0uEKc"
      "source": [
        "Question2 = 'What color'\n",
        "Answer2 = ['green','yellow' ]\n",
        "indices2 = [answer_to_index[ans] for ans in Answer2]\n",
        "img2 = Image.open('hydrant.png')"
      "execution_count": null,
      "outputs": []
      "cell_type": "code",
      "metadata": {
        "id": "Byih1NwIuEKd"
      "source": [
        "dir_path = 'hydrant.png' \n",
        "inputs = preprocess(dir_path, Question2)\n",
        "ans = vqa_resnet(*inputs)  # use the model to get the raw class scores\n",
        "answer_idx = ans.argmax(dim=1).item()  # softmax does not change the argmax\n",
        "print('Predicted answer: ' + answer_words[answer_idx])"
      "execution_count": null,
      "outputs": []
      "cell_type": "markdown",
      "metadata": {
        "id": "7N6shd0ruEKe"
      "source": [
        "### Grad-CAM \n",
        "* **Overview:** Given an image with a question, and a category (‘dog’) as input, we forward propagate the image through the model to obtain the `raw class scores` before softmax. The gradients are set to zero for all classes except the desired class (dog), which is set to 1. This signal is then backpropagated to the `rectified convolutional feature map` of interest, where we can compute the coarse Grad-CAM localization (blue heatmap).\n",
        "* **To Do**: Define your own function Grad_CAM to achieve the visualization of the two images. For each image, consider the answers we provided as the desired classes. Compare the heatmaps of different answers, and conclude. \n",
        "* **Hints**: \n",
        " + We need to record the output and the gradient of the output (`grad_output`) of the feature maps to compute Grad-CAM. In PyTorch, *hooks* serve this purpose: `register_forward_hook` and `register_full_backward_hook` attach functions that are called during the forward and backward passes of a module. Read the tutorial on [hook]( carefully. \n",
        " + The pretrained model `vqa_resnet` has no activation function after its last layer, so its output is indeed the `raw class scores` and you can use it directly. Run `print(vqa_resnet)` to get more information on the model.\n",
        " + The last CNN layer of the model is: `vqa_resnet.resnet_layer4.r_model.layer4[2].conv3` \n",
        " + The feature maps are 14x14, and so is your heatmap. You need to project the heatmap onto the original image (224x224) to get a better view. The function `cv2.resize()` may help.  \n",
        " + Here is the link to the paper [Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization]("
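The overview above can be checked numerically on a toy head where the gradient is known in closed form. This NumPy sketch assumes a hypothetical classifier made of global average pooling followed by a linear layer, with made-up sizes and random values; in the real model the gradient comes from autograd instead. Backpropagating a one-hot signal (1 for the target class, 0 elsewhere) gives the gradient of that class score alone:

```python
import numpy as np

rng = np.random.default_rng(0)
K, H, W, C = 3, 4, 4, 2            # channels, spatial size, number of classes (toy values)
A = np.maximum(rng.normal(size=(K, H, W)), 0)  # rectified feature maps A^k
V = rng.normal(size=(C, K))        # hypothetical linear classifier on pooled features

def class_scores(A):
    """Toy head: global average pooling then a linear layer (raw scores, no softmax)."""
    return V @ A.mean(axis=(1, 2))

def grad_wrt_features(target):
    """Backpropagate a one-hot signal to A. For this head,
    d y^c / d A^k_ij = V[c, k] / (H * W) at every spatial position."""
    one_hot = np.zeros(C)
    one_hot[target] = 1.0
    per_channel = one_hot @ V / (H * W)           # shape (K,)
    return np.broadcast_to(per_channel[:, None, None], A.shape)

def grad_cam(A, grads):
    alpha = grads.mean(axis=(1, 2))               # channel weights alpha_k^c
    cam = np.tensordot(alpha, A, axes=1)          # sum_k alpha_k^c A^k, shape (H, W)
    return np.maximum(cam, 0)                     # ReLU keeps positive evidence only

heatmap = grad_cam(A, grad_wrt_features(target=1))
print(heatmap.shape)
```

For this particular head Grad-CAM reduces to the class activation map (CAM), since each channel weight is proportional to the classifier weight of that channel.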
      "cell_type": "markdown",
      "metadata": {
        "id": "rC-bzCdIuEKe"
      "source": [
        "<img src=\"grad_cam.png\">"
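For reference, the two quantities in the figure, as defined in the Grad-CAM paper, are the channel weights (the gradients of the raw class score $y^c$ global-average-pooled over the $Z$ spatial positions of the feature map $A^k$) and the rectified weighted combination of the feature maps:

```latex
\alpha_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^c}{\partial A_{ij}^k},
\qquad
L_{\text{Grad-CAM}}^c = \mathrm{ReLU}\Big( \sum_k \alpha_k^c A^k \Big)
```

With 14x14 feature maps, $Z = 196$. Averaging over $k$ instead of summing only rescales the map by a constant, which disappears once the heatmap is normalized to $[0, 1]$.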
      "cell_type": "code",
      "metadata": {
        "id": "nRKEVO37uEKe"
      "source": [
        "# Prepare the hooks on the last convolutional layer of vqa_resnet\n",
        "target_layer = vqa_resnet.resnet_layer4.r_model.layer4[2].conv3\n",
        "\n",
        "def forward_hook(module, input, output):\n",
        "    # store the feature maps A^k produced by the forward pass\n",
        "    global layer_activations\n",
        "    layer_activations = output\n",
        "\n",
        "def backward_hook(module, grad_input, grad_output):\n",
        "    # store the gradient of the class score w.r.t. the feature maps\n",
        "    global layer_gradients\n",
        "    layer_gradients = grad_output[0]\n",
        "\n",
        "handle_forward = target_layer.register_forward_hook(forward_hook)\n",
        "handle_backward = target_layer.register_full_backward_hook(backward_hook)"
      "execution_count": null,
      "outputs": []
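The same hook mechanism can be tried on a small network before touching `vqa_resnet`. This sketch assumes a hypothetical tiny CNN (random weights, 5 classes) and stores the tensors in a dict rather than in globals; `register_full_backward_hook` is available in PyTorch 1.8 and later:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical tiny CNN standing in for vqa_resnet; we hook its last conv layer
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 5),
)
target_layer = model[2]
store = {}

def forward_hook(module, inputs, output):
    store['activations'] = output.detach()          # feature maps A^k

def backward_hook(module, grad_input, grad_output):
    store['gradients'] = grad_output[0].detach()    # d y^c / d A^k

h_fwd = target_layer.register_forward_hook(forward_hook)
h_bwd = target_layer.register_full_backward_hook(backward_hook)

x = torch.randn(1, 3, 14, 14)
scores = model(x)                  # raw class scores, shape (1, 5)
model.zero_grad()
scores[0, 2].backward()            # one-hot signal on class 2

# channel weights, weighted sum of the feature maps, then ReLU
weights = store['gradients'].mean(dim=(2, 3), keepdim=True)
cam = torch.relu((weights * store['activations']).sum(dim=1))[0]
print(cam.shape)

h_fwd.remove()                     # remove the hooks once done
h_bwd.remove()
```

Detaching the stored tensors keeps later manipulations of the heatmap out of the autograd graph.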
      "cell_type": "code",
      "metadata": {
        "id": "pr5lNX3z4T85"
      "source": [
        "# handle_forward.remove()\r\n",
        "# handle_backward.remove()"
      "execution_count": null,
      "outputs": []
      "cell_type": "code",
      "metadata": {
        "id": "6mit1YfULcdI"
      "source": [
        "def grad_cam(question, answer_idx=None):\n",
        "    ''' Takes the question number (1 or 2) and an optional answer index,\n",
        "    and saves the heatmap overlaid on the image in the working directory '''\n",
        "    # prepare the inputs\n",
        "    if question == 2:\n",
        "        dir_path = 'hydrant.png'\n",
        "        inputs = preprocess(dir_path, Question2)\n",
        "    elif question == 1:\n",
        "        dir_path = 'dog_cat.png'\n",
        "        inputs = preprocess(dir_path, Question1)\n",
        "    else:\n",
        "        return 'Not a valid question'\n",
        "    # propagate through the model to obtain the raw class scores\n",
        "    print('Propagating to get the predictions ...')\n",
        "    ans = vqa_resnet(*inputs)\n",
        "    predicted_idx = ans.argmax(dim=1).item()\n",
        "    print('Predicted answer: ' + answer_words[predicted_idx])\n",
        "    if answer_idx is None:\n",
        "        answer_idx = predicted_idx\n",
        "    print('Target answer: ' + answer_words[answer_idx])\n",
        "    print('Preparing heatmap ...')\n",
        "    # backpropagate from the raw score of the target answer only\n",
        "    vqa_resnet.zero_grad()\n",
        "    ans[0, answer_idx].backward()\n",
        "    # detach the hooked tensors so they can be manipulated freely\n",
        "    activations = layer_activations.detach()\n",
        "    gradients = layer_gradients.detach()\n",
        "    # channel weights: average the gradients over the batch and spatial dims\n",
        "    avg_gradients = torch.mean(gradients, dim=[0, 2, 3])\n",
        "    # weight each channel of the activations by its importance\n",
        "    weighted = activations * avg_gradients.view(1, -1, 1, 1)\n",
        "    # average the weighted channels (proportional to the sum used in the paper)\n",
        "    heatmap = torch.mean(weighted, dim=1).squeeze().cpu().numpy()\n",
        "    print('Heatmap Preview:')\n",
        "    plt.matshow(heatmap)\n",
        "    # ReLU: keep only positive contributions, then rescale to [0, 1]\n",
        "    heatmap = np.maximum(heatmap, 0)\n",
        "    heatmap /= np.amax(heatmap) + 1e-8\n",
        "    # upsample the 14x14 map to the image size and colorize it\n",
        "    img = cv2.imread(dir_path)\n",
        "    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))\n",
        "    heatmap = np.uint8(255 * heatmap)\n",
        "    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)\n",
        "    # overlay the heatmap on the image and clip to the valid 8-bit range\n",
        "    projected_heatmap = np.uint8(np.clip(heatmap * 0.4 + img, 0, 255))\n",
        "    save_path = dir_path.split('.')[0] + '_' + answer_words[answer_idx] + '_heatmap.jpg'\n",
        "    if cv2.imwrite(save_path, projected_heatmap):\n",
        "        print('Heatmap successfully saved at ' + save_path)\n",
        "    else:\n",
        "        print('Error in saving the heatmap')"
      "execution_count": null,
      "outputs": []
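The post-processing at the end of `grad_cam` (ReLU, rescaling, upsampling, blending) can be sketched in plain NumPy. Here nearest-neighbour upsampling with `np.repeat` stands in for `cv2.resize` (which interpolates bilinearly by default), the heat is drawn in the red channel only instead of a JET colormap, and the image size is assumed to be a multiple of the map size (224 = 16 x 14). The example also shows why clipping matters: blending can exceed 255, and an unclipped cast to `uint8` does not saturate reliably.

```python
import numpy as np

def overlay(cam, img_bgr, alpha=0.4):
    """Turn a raw low-resolution CAM into a heat overlay on an 8-bit BGR image."""
    cam = np.maximum(cam, 0)                        # ReLU
    cam = cam / (cam.max() + 1e-8)                  # scale to [0, 1], safe if all zeros
    fy = img_bgr.shape[0] // cam.shape[0]
    fx = img_bgr.shape[1] // cam.shape[1]
    up = np.repeat(np.repeat(cam, fy, axis=0), fx, axis=1)  # nearest-neighbour upsampling
    heat = np.zeros_like(img_bgr, dtype=np.float64)
    heat[..., 2] = 255.0 * up                       # channel 2 is red in BGR order
    blended = alpha * heat + img_bgr.astype(np.float64)
    return np.clip(blended, 0, 255).astype(np.uint8)  # clip before the uint8 cast

# Hypothetical 14x14 CAM and a nearly white 224x224 image
rng = np.random.default_rng(0)
cam = rng.normal(size=(14, 14))
img = np.full((224, 224, 3), 250, dtype=np.uint8)
out = overlay(cam, img)
print(out.dtype, out.shape, out.max())
```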
      "cell_type": "code",
      "metadata": {
        "id": "MFcOB2P4Hvev"
      "source": [
        "# compute one heatmap per provided answer (dog, cat, green, yellow)\n",
        "for question, answers in [(1, Answer1), (2, Answer2)]:\n",
        "    for answer in answers:\n",
        "        print('-> Question ' + str(question) + ', answer: ' + answer)\n",
        "        grad_cam(question, answer_idx=answer_to_index[answer])"
      "execution_count": null,
      "outputs": []
      "cell_type": "code",
      "metadata": {
        "id": "VHhtTXOE7c9s"
      "source": [
        "hm_dog = Image.open('dog_cat_dog_heatmap.jpg')\n",
        "hm_cat = Image.open('dog_cat_cat_heatmap.jpg')\n",
        "plt.subplot(1, 2, 1)\n",
        "plt.imshow(hm_dog)\n",
        "plt.title('Heatmap for Dog answer')\n",
        "plt.subplot(1, 2, 2)\n",
        "plt.imshow(hm_cat)\n",
        "plt.title('Heatmap for Cat answer')\n",
        "plt.show()"
      "execution_count": null,
      "outputs": []
      "cell_type": "code",
      "metadata": {
        "id": "-QiattRZBWKe"
      "source": [
        "hm_green = Image.open('hydrant_green_heatmap.jpg')\n",
        "hm_yellow = Image.open('hydrant_yellow_heatmap.jpg')\n",
        "plt.subplot(1, 2, 1)\n",
        "plt.imshow(hm_green)\n",
        "plt.title('Heatmap for Green answer')\n",
        "plt.subplot(1, 2, 2)\n",
        "plt.imshow(hm_yellow)\n",
        "plt.title('Heatmap for Yellow answer')\n",
        "plt.show()"
      "execution_count": null,
      "outputs": []
      "cell_type": "markdown",
      "metadata": {
        "id": "B9Oo1BYfQ77e"
      "source": [
        "As we can see, the heatmap takes higher values near the object the answer refers to (the dog's face for the answer `dog`, the cat's face for the answer `cat`, etc.). It is therefore a good way to check whether the network focuses on the right part of the image, which makes Grad-CAM a useful tool for explainability. "