{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "dldiy",
      "language": "python",
      "name": "dldiy"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.5"
    },
    "colab": {
      "name": "GML_spectral_gnn.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cAN8ZbiRIM9C"
      },
      "source": [
        "# Graph ConvNets in PyTorch\n",
        "\n",
        "PyTorch implementation of the NeurIPS'16 paper:\n",
        "Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering\n",
        "M Defferrard, X Bresson, P Vandergheynst\n",
        "Advances in Neural Information Processing Systems, 3844-3852, 2016\n",
        "[ArXiv preprint](https://arxiv.org/abs/1606.09375)\n",
        "\n",
        "Adapted from Xavier Bresson's repo: [spectral_graph_convnets](https://github.com/xbresson/spectral_graph_convnets) for [dataflowr](https://dataflowr.github.io/website/) by [Marc Lelarge](https://www.di.ens.fr/~lelarge/)\n",
        "\n",
        "## objective:\n",
        "\n",
        "The code provides a simple example of graph ConvNets for the MNIST classification task.\n",
        "The graph is a 8-nearest neighbor graph of a 2D grid.\n",
        "The signals on graph are the MNIST images vectorized as $28^2 \\times 1$ vectors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "R84-PMawIM9H",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bcf8049f-3913-496e-d11f-e4d65f19b14c"
      },
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "import torch.nn as nn\n",
        "import collections\n",
        "import time\n",
        "import numpy as np\n",
        "import scipy\n",
        "from functools import partial\n",
        "import os\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    print('cuda available')\n",
        "    dtypeFloat = torch.cuda.FloatTensor\n",
        "    dtypeLong = torch.cuda.LongTensor\n",
        "    torch.cuda.manual_seed(1)\n",
        "else:\n",
        "    print('cuda not available')\n",
        "    dtypeFloat = torch.FloatTensor\n",
        "    dtypeLong = torch.LongTensor\n",
        "    torch.manual_seed(1)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "cuda available\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YPM7M4H7IM9I"
      },
      "source": [
        "## Download the data\n",
        "\n",
        "If you are running on colab, follow the instructions below.\n",
        "\n",
        "If you cloned the repo, go directly to the tempory hack."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MJtTTTgeIM9I"
      },
      "source": [
        "### Colab setting\n",
        "\n",
        "If you run this notebook on colab, please uncomment (and run) the following cells."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5ka4jBE3IM9I",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "06b59581-5757-42ea-8a0c-ac4c7cf17c6a"
      },
      "source": [
        "!wget www.di.ens.fr/~lelarge/graphs.tar.gz"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "--2021-03-16 20:24:19--  http://www.di.ens.fr/~lelarge/graphs.tar.gz\n",
            "Resolving www.di.ens.fr (www.di.ens.fr)... 129.199.99.14\n",
            "Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:80... connected.\n",
            "HTTP request sent, awaiting response... 302 Found\n",
            "Location: https://www.di.ens.fr/~lelarge/graphs.tar.gz [following]\n",
            "--2021-03-16 20:24:19--  https://www.di.ens.fr/~lelarge/graphs.tar.gz\n",
            "Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: unspecified [application/x-gzip]\n",
            "Saving to: ‘graphs.tar.gz’\n",
            "\n",
            "graphs.tar.gz           [              <=>   ]  66.42M  22.9MB/s    in 2.9s    \n",
            "\n",
            "2021-03-16 20:24:22 (22.9 MB/s) - ‘graphs.tar.gz’ saved [69647028]\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2pKymA5sIM9J",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "9d4920fb-c964-4b91-9957-ca88b190f00c"
      },
      "source": [
        "!tar -zxvf graphs.tar.gz"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "graphs/\n",
            "graphs/lib/\n",
            "graphs/lib/grid_graph.py\n",
            "graphs/lib/coarsening.py\n",
            "graphs/.spectral_gnn.ipynb.swp\n",
            "graphs/mnist/\n",
            "graphs/mnist/temp/\n",
            "graphs/mnist/temp/MNIST/\n",
            "graphs/mnist/temp/MNIST/raw/\n",
            "graphs/mnist/temp/MNIST/raw/t10k-images-idx3-ubyte\n",
            "graphs/mnist/temp/MNIST/raw/t10k-labels-idx1-ubyte\n",
            "graphs/mnist/temp/MNIST/raw/t10k-images-idx3-ubyte.gz\n",
            "graphs/mnist/temp/MNIST/raw/t10k-labels-idx1-ubyte.gz\n",
            "graphs/mnist/temp/MNIST/raw/train-images-idx3-ubyte\n",
            "graphs/mnist/temp/MNIST/raw/train-images-idx3-ubyte.gz\n",
            "graphs/mnist/temp/MNIST/raw/train-labels-idx1-ubyte\n",
            "graphs/mnist/temp/MNIST/raw/train-labels-idx1-ubyte.gz\n",
            "graphs/mnist/temp/MNIST/processed/\n",
            "graphs/mnist/temp/MNIST/processed/test.pt\n",
            "graphs/mnist/temp/MNIST/processed/training.pt\n",
            "graphs/mnist/temp/MNIST.tar.gz\n",
            "graphs/spectral_gnn.ipynb\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rn2t4_jpIM9K",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "873fe3be-a79c-45e5-c157-7a4b2f2d6ed2"
      },
      "source": [
        "%cd graphs"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/content/graphs\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OL9d2fhVIM9K"
      },
      "source": [
        "### temporary hack \n",
        "\n",
        "Unecessary if running on colab (or if you already have MNIST), see this [issue](https://github.com/pytorch/vision/issues/3497)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "exeFC4D2IM9L"
      },
      "source": [
        "!mkdir mnist\n",
        "%cd mnist\n",
        "!mkdir temp\n",
        "%cd temp\n",
        "!wget www.di.ens.fr/~lelarge/MNIST.tar.gz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZPLcJokiIM9L"
      },
      "source": [
        "!tar -zxvf MNIST.tar.gz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e3V0ifgWIM9M"
      },
      "source": [
        "%cd ..\n",
        "%cd .."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s-QfoaerIM9M"
      },
      "source": [
        "## Loading the data"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fU9R11odIM9M",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "d55c04c2-a6aa-45d1-bb7b-87b877972756"
      },
      "source": [
        "def check_mnist_dataset_exists(path_data='./'):\n",
        "    flag_train_data = os.path.isfile(path_data + 'mnist/train_data.pt') \n",
        "    flag_train_label = os.path.isfile(path_data + 'mnist/train_label.pt') \n",
        "    flag_test_data = os.path.isfile(path_data + 'mnist/test_data.pt') \n",
        "    flag_test_label = os.path.isfile(path_data + 'mnist/test_label.pt') \n",
        "    if flag_train_data==False or flag_train_label==False or flag_test_data==False or flag_test_label==False:\n",
        "        print('MNIST dataset preprocessing...')\n",
        "        import torchvision\n",
        "        import torchvision.transforms as transforms\n",
        "        trainset = torchvision.datasets.MNIST(root=path_data + 'mnist/temp', train=True,\n",
        "                                                download=True, transform=transforms.ToTensor())\n",
        "        testset = torchvision.datasets.MNIST(root=path_data + 'mnist/temp', train=False,\n",
        "                                               download=True, transform=transforms.ToTensor())\n",
        "        train_data=torch.Tensor(60000,28,28)\n",
        "        train_label=torch.LongTensor(60000)\n",
        "        for idx , example in enumerate(trainset):\n",
        "            train_data[idx]=example[0].squeeze()\n",
        "            train_label[idx]=example[1]\n",
        "        torch.save(train_data,path_data + 'mnist/train_data.pt')\n",
        "        torch.save(train_label,path_data + 'mnist/train_label.pt')\n",
        "        test_data=torch.Tensor(10000,28,28)\n",
        "        test_label=torch.LongTensor(10000)\n",
        "        for idx , example in enumerate(testset):\n",
        "            test_data[idx]=example[0].squeeze()\n",
        "            test_label[idx]=example[1]\n",
        "        torch.save(test_data,path_data + 'mnist/test_data.pt')\n",
        "        torch.save(test_label,path_data + 'mnist/test_label.pt')\n",
        "    return path_data\n",
        "\n",
        "\n",
        "_ = check_mnist_dataset_exists()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "MNIST dataset preprocessing...\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9-FvItWZIM9N",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bb120523-37fd-45d9-9f77-30c7477654c7"
      },
      "source": [
        "#if you want to play with a small dataset (for cpu), uncomment.\n",
        "#nb_selected_train_data = 500\n",
        "#nb_selected_test_data = 100\n",
        "\n",
        "train_data=torch.load('mnist/train_data.pt').reshape(60000,784).numpy()\n",
        "#train_data = train_data[:nb_selected_train_data,:]\n",
        "print(train_data.shape)\n",
        "\n",
        "train_labels=torch.load('mnist/train_label.pt').numpy()\n",
        "#train_labels = train_labels[:nb_selected_train_data]\n",
        "print(train_labels.shape)\n",
        "\n",
        "test_data=torch.load('mnist/test_data.pt').reshape(10000,784).numpy()\n",
        "#test_data = test_data[:nb_selected_test_data,:]\n",
        "print(test_data.shape)\n",
        "\n",
        "test_labels=torch.load('mnist/test_label.pt').numpy()\n",
        "#test_labels = test_labels[:nb_selected_test_data]\n",
        "print(test_labels.shape)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "(60000, 784)\n",
            "(60000,)\n",
            "(10000, 784)\n",
            "(10000,)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z4d8RsqlIM9O",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "1cdc7afa-a9da-4029-887d-379d0d05c39d"
      },
      "source": [
        "from lib.grid_graph import grid_graph\n",
        "from lib.coarsening import coarsen, HEM, compute_perm, perm_adjacency\n",
        "from lib.coarsening import perm_data\n",
        "\n",
        "# Construct graph\n",
        "t_start = time.time()\n",
        "grid_side = 28\n",
        "number_edges = 8\n",
        "metric = 'euclidean'\n",
        "\n",
        "\n",
        "######## YOUR GRAPH ADJACENCY MATRIX HERE ########\n",
        "A = grid_graph(grid_side,number_edges,metric) # create graph of Euclidean grid\n",
        "######## YOUR GRAPH ADJACENCY MATRIX HERE ########\n",
        "print('Execution time: {:.2f}s'.format(time.time() - t_start))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "nb edges:  6396\n",
            "Execution time: 0.09s\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x1bL7L1JIM9P"
      },
      "source": [
        "def laplacian(W, normalized=True):\n",
        "    \"\"\"Return graph Laplacian\"\"\"\n",
        "    I = scipy.sparse.identity(W.shape[0], dtype=W.dtype)\n",
        "\n",
        "    #W += I\n",
        "    # Degree matrix.\n",
        "    d = W.sum(axis=0)\n",
        "\n",
        "    # Laplacian matrix.\n",
        "    if not normalized:\n",
        "        D = scipy.sparse.diags(d.A.squeeze(), 0)\n",
        "        L = D - W\n",
        "    else:\n",
        "        d = 1 / np.sqrt(d)\n",
        "        D = scipy.sparse.diags(d.A.squeeze(), 0)\n",
        "        I = scipy.sparse.identity(d.size, dtype=W.dtype)\n",
        "        L = I - D @ W @ D\n",
        "\n",
        "    assert np.abs(L - L.T).mean() < 1e-8\n",
        "    assert type(L) is scipy.sparse.csr.csr_matrix\n",
        "    return L"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6PsIUNxsIM9P"
      },
      "source": [
        "def rescale_L(L, lmax=2):\n",
        "    \"\"\"Rescale Laplacian eigenvalues to [-1,1]\"\"\"\n",
        "    M, M = L.shape\n",
        "    I = scipy.sparse.identity(M, format='csr', dtype=L.dtype)\n",
        "    L /= lmax * 2\n",
        "    # L = 2 * L/lmax\n",
        "    L -= I\n",
        "    return L \n",
        "\n",
        "def lmax_L(L):\n",
        "    \"\"\"Compute largest Laplacian eigenvalue\"\"\"\n",
        "    return scipy.sparse.linalg.eigsh(L, k=1, which='LM', return_eigenvectors=False)[0]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QSy9EM2KIM9P",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "9b43f1cd-60dc-4ca8-c764-eaf10f61fe41"
      },
      "source": [
        "# Compute coarsened graphs\n",
        "t_start = time.time()\n",
        "coarsening_levels = 4\n",
        "\n",
        "L, perm = coarsen(A, coarsening_levels, partial(laplacian, normalized=True))\n",
        "\n",
        "# Compute max eigenvalue of graph Laplacians\n",
        "lmax = []\n",
        "for i in range(coarsening_levels):\n",
        "    lmax.append(lmax_L(L[i]))\n",
        "print('lmax: ' + str([lmax[i] for i in range(coarsening_levels)]))\n",
        "\n",
        "# Reindex nodes to satisfy a binary tree structure\n",
        "train_data = perm_data(train_data, perm)\n",
        "test_data = perm_data(test_data, perm)\n",
        "\n",
        "print('Execution time: {:.2f}s'.format(time.time() - t_start))\n",
        "del perm"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Heavy Edge Matching coarsening with Xavier version\n",
            "Layer 0: M_0 = |V| = 960 nodes (176 added), |E| = 3198 edges\n",
            "Layer 1: M_1 = |V| = 480 nodes (77 added), |E| = 1618 edges\n",
            "Layer 2: M_2 = |V| = 240 nodes (29 added), |E| = 781 edges\n",
            "Layer 3: M_3 = |V| = 120 nodes (7 added), |E| = 388 edges\n",
            "Layer 4: M_4 = |V| = 60 nodes (0 added), |E| = 194 edges\n",
            "lmax: [1.3857539, 1.3440949, 1.2277015, 0.9999999]\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:14: RuntimeWarning: divide by zero encountered in true_divide\n",
            "  \n",
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:14: RuntimeWarning: divide by zero encountered in true_divide\n",
            "  \n",
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:14: RuntimeWarning: divide by zero encountered in true_divide\n",
            "  \n",
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:14: RuntimeWarning: divide by zero encountered in true_divide\n",
            "  \n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Execution time: 1.34s\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uc0SwCcgIM9R"
      },
      "source": [
        "Here, we implemented the pooling layers and computed the list `L` containing the Laplacians of the graphs for each layer.\n",
        "\n",
        "## <font color='red'>Question 1: what is the size of the various poolings?</font> "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WJqeeflxfqYS"
      },
      "source": [
        "We observe that the number of nodes `|V|` decreases by 2 at each layer, meaning that the various poolings have size 2"
      ]
    },
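    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check, the sizes can be read off the coarsening log printed above (this is only a worked reading of that output, not an additional result). The $28 \\times 28$ grid has $784$ nodes and the log reports $176$ added nodes at layer 0, which gives\n",
        "\n",
        "$$784 + 176 = 960, \\qquad 960 \\to 480 \\to 240 \\to 120 \\to 60.$$\n",
        "\n",
        "Each arrow divides the number of nodes by $2$, so chaining two levels gives a pooling of size $2 \\times 2 = 4$, for example $960 \\to 240$."
      ]
    },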
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GY4IypYiIM9R"
      },
      "source": [
        "# Graph ConvNet LeNet5\n",
        "\n",
        "## Layers: CL32-MP4-CL64-MP4-FC512-FC10\n",
        "\n",
        "As described above, this network has 2 graph convolutional layers and two pooling layers with size 4.\n",
        "\n",
        "## <font color='red'>Question 2: which graphs will you take in the list `L` for the graph convolutional layers?</font> \n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rTMYxUq7JDhx"
      },
      "source": [
        "The first Laplacian to consider is the one of the original data, i.e. **`L[0]`**. Then, there is a pooling layer of size $4$, meaning that the next graph to be considered contains the original number of nodes divided by $4$. Therefore, the other Laplacian to consider is **`L[2]`**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PH0WoS47I0Zm"
      },
      "source": [
        "In the code below, you will need to complete the `graph_conv_cheby` and the `graph_max_pool`.\n",
        "\n",
        "Hint: each time you permute dimenstions, it is safe to add a `contiguous` like below:\n",
        "`x0 = x.permute(1,2,0).contiguous()` see [here](https://discuss.pytorch.org/t/call-contiguous-after-every-permute-call/13190/2) "
      ]
    },
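    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the filtering that `graph_conv_cheby` has to compute can be summarized as follows, in the spirit of the paper cited at the top. This is a sketch: the symbols $\\tilde{L}$, $\\bar{x}_k$ and $\\theta_k$ are notation introduced here, not names from the code. For a signal $x$ and a rescaled Laplacian $\\tilde{L}$ whose eigenvalues lie in $[-1, 1]$, the Chebyshev basis is built by the recurrence\n",
        "\n",
        "$$\\bar{x}_0 = x, \\qquad \\bar{x}_1 = \\tilde{L} x, \\qquad \\bar{x}_k = 2 \\tilde{L} \\bar{x}_{k-1} - \\bar{x}_{k-2} \\quad (2 \\le k \\le K-1),$$\n",
        "\n",
        "and the filtered signal is the linear combination\n",
        "\n",
        "$$y = \\sum_{k=0}^{K-1} \\theta_k \\bar{x}_k.$$\n",
        "\n",
        "In the paper the rescaling is $\\tilde{L} = 2L/\\lambda_{\\max} - I$. With a single input feature, each row of the weight matrix of an `nn.Linear(K, Fout)` layer holds the $K$ coefficients $\\theta_k$ of one output feature (plus a bias), which is why the $K$ vectors $\\bar{x}_k$ are stacked along the last dimension before the linear layer is applied."
      ]
    },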
    {
      "cell_type": "code",
      "metadata": {
        "id": "UKiPIVdKIM9R"
      },
      "source": [
        "class Graph_ConvNet_LeNet5(nn.Module):\n",
        "    \n",
        "    def __init__(self, net_parameters):\n",
        "        \n",
        "        print('Graph ConvNet: LeNet5')\n",
        "        \n",
        "        super(Graph_ConvNet_LeNet5, self).__init__()\n",
        "        \n",
        "        # parameters\n",
        "        D, CL1_F, CL1_K, CL2_F, CL2_K, FC1_F, FC2_F = net_parameters\n",
        "        FC1Fin = CL2_F*(D//16)\n",
        "        \n",
        "        # graph CL1\n",
        "        self.cl1 = nn.Linear(CL1_K, CL1_F) \n",
        "        self.init_layers(self.cl1, CL1_K, CL1_F)\n",
        "        self.CL1_K = CL1_K; self.CL1_F = CL1_F; \n",
        "        \n",
        "        # graph CL2\n",
        "        self.cl2 = nn.Linear(CL2_K*CL1_F, CL2_F) \n",
        "        self.init_layers(self.cl2, CL2_K*CL1_F, CL2_F)\n",
        "        self.CL2_K = CL2_K; self.CL2_F = CL2_F; \n",
        "\n",
        "        # FC1\n",
        "        self.fc1 = nn.Linear(FC1Fin, FC1_F) \n",
        "        self.init_layers(self.fc1, FC1Fin, FC1_F)\n",
        "        self.FC1Fin = FC1Fin\n",
        "        \n",
        "        # FC2\n",
        "        self.fc2 = nn.Linear(FC1_F, FC2_F)\n",
        "        self.init_layers(self.fc2, FC1_F, FC2_F)\n",
        "\n",
        "        # nb of parameters\n",
        "        nb_param = CL1_K* CL1_F + CL1_F          # CL1\n",
        "        nb_param += CL2_K* CL1_F* CL2_F + CL2_F  # CL2\n",
        "        nb_param += FC1Fin* FC1_F + FC1_F        # FC1\n",
        "        nb_param += FC1_F* FC2_F + FC2_F         # FC2\n",
        "        print('nb of parameters=',nb_param,'\\n')\n",
        "        \n",
        "        \n",
        "    def init_layers(self, W, Fin, Fout):\n",
        "\n",
        "        scale = np.sqrt( 2.0/ (Fin+Fout) )\n",
        "        W.weight.data.uniform_(-scale, scale)\n",
        "        W.bias.data.fill_(0.0)\n",
        "\n",
        "        return W\n",
        "        \n",
        "        \n",
        "    def graph_conv_cheby(self, x, cl, L, lmax, Fout, K):\n",
        "        # parameters\n",
        "        # B = batch size\n",
        "        # V = nb vertices\n",
        "        # Fin = nb input features\n",
        "        # Fout = nb output features\n",
        "        # K = Chebyshev order & support size\n",
        "        B, V, Fin = x.size(); B, V, Fin = int(B), int(V), int(Fin) \n",
        "\n",
        "        # rescale Laplacian\n",
        "        # lmax = lmax_L(L)\n",
        "        # L = rescale_L(L, lmax) \n",
        "        L = 0.25 * (L - scipy.sparse.identity(V, dtype=L.dtype))\n",
        "        \n",
        "        # convert scipy sparse matric L to pytorch\n",
        "        L = L.tocoo()\n",
        "        indices = np.column_stack((L.row, L.col)).T \n",
        "        indices = indices.astype(np.int64)\n",
        "        indices = torch.from_numpy(indices)\n",
        "        indices = indices.type(torch.LongTensor)\n",
        "        L_data = L.data.astype(np.float32)\n",
        "        L_data = torch.from_numpy(L_data) \n",
        "        L_data = L_data.type(torch.FloatTensor)\n",
        "        L = torch.sparse.FloatTensor(indices, L_data, torch.Size(L.shape))\n",
        "        L.requires_grad_(False)\n",
        "        if torch.cuda.is_available():\n",
        "            L = L.cuda()\n",
        "        \n",
        "        # transform to Chebyshev basis\n",
        "        x = x.permute(1,2,0).contiguous()\n",
        "        x = x.view([V, Fin*B])\n",
        "        x = x.unsqueeze(0).repeat(K, 1, 1)\n",
        "        if K > 1: \n",
        "            x[1] = torch.sparse.mm(L,x[0])\n",
        "        for k in range(2, K):\n",
        "            x[k] = 2 * torch.sparse.mm(L,x[k-1]) - x[k-2]\n",
        "        x = x.view([K, V, Fin, B]).permute(3,1,2,0).contiguous()     \n",
        "        x = x.view([B*V, Fin*K])\n",
        "\n",
        "        x = cl(x) \n",
        "        x = x.view([B, V, Fout])\n",
        "\n",
        "        return x\n",
        "        \n",
        "        \n",
        "    # Max pooling of size p. Must be a power of 2.\n",
        "    def graph_max_pool(self, x, p): \n",
        "        # input B x V x F output B x V/p x F\n",
        "        x = nn.MaxPool1d(p)(x.permute(0,2,1).contiguous())\n",
        "        x = x.permute(0,2,1).contiguous()  \n",
        "        return x    \n",
        "        \n",
        "        \n",
        "    def forward(self, x, d, L, lmax):\n",
        "        # graph CL1\n",
        "        x = x.unsqueeze(2) # B x V x Fin=1  \n",
        "        x = self.graph_conv_cheby(x, self.cl1, L[0], lmax[0], self.CL1_F, self.CL1_K)\n",
        "        x = F.relu(x)\n",
        "        x = self.graph_max_pool(x, 4)\n",
        "        # graph CL2\n",
        "        x = self.graph_conv_cheby(x, self.cl2, L[2], lmax[2], self.CL2_F, self.CL2_K)\n",
        "        x = F.relu(x)\n",
        "        x = self.graph_max_pool(x, 4)\n",
        "        # FC1\n",
        "        x = x.view(-1, self.FC1Fin)\n",
        "        x = self.fc1(x)\n",
        "        x = F.relu(x)\n",
        "        x  = nn.Dropout(d)(x)\n",
        "        # FC2\n",
        "        x = self.fc2(x)\n",
        "        return x\n",
        "        \n",
        "        \n",
        "    def loss(self, y, y_target, l2_regularization):\n",
        "    \n",
        "        loss = nn.CrossEntropyLoss()(y,y_target)\n",
        "\n",
        "        l2_loss = 0.0\n",
        "        for param in self.parameters():\n",
        "            data = param* param\n",
        "            l2_loss += data.sum()\n",
        "           \n",
        "        loss += 0.5* l2_regularization* l2_loss\n",
        "            \n",
        "        return loss\n",
        "    \n",
        "    \n",
        "    def update(self, lr):\n",
        "                \n",
        "        update = torch.optim.SGD( self.parameters(), lr=lr, momentum=0.9 )\n",
        "        \n",
        "        return update\n",
        "        \n",
        "        \n",
        "    def update_learning_rate(self, optimizer, lr):\n",
        "   \n",
        "        for param_group in optimizer.param_groups:\n",
        "            param_group['lr'] = lr\n",
        "\n",
        "        return optimizer\n",
        "\n",
        "    \n",
        "    def evaluation(self, y_predicted, test_l):\n",
        "    \n",
        "        _, class_predicted = torch.max(y_predicted.data, 1)\n",
        "        return 100.0* (class_predicted == test_l).sum()/ y_predicted.size(0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hzp6aP9YIM9T",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "14f822d9-a425-4a55-922b-70518beea634"
      },
      "source": [
        "# Delete existing network if exists\n",
        "try:\n",
        "    del net\n",
        "    print('Delete existing network\\n')\n",
        "except NameError:\n",
        "    print('No existing network to delete\\n')\n",
        "\n",
        "# network parameters\n",
        "D = train_data.shape[1]\n",
        "CL1_F = 32\n",
        "CL1_K = 25\n",
        "CL2_F = 64\n",
        "CL2_K = 25\n",
        "FC1_F = 512\n",
        "FC2_F = 10\n",
        "net_parameters = [D, CL1_F, CL1_K, CL2_F, CL2_K, FC1_F, FC2_F]\n",
        "dropout_value = 0.5\n",
        "\n",
        "# instantiate the object net of the class \n",
        "net = Graph_ConvNet_LeNet5(net_parameters)\n",
        "if torch.cuda.is_available():\n",
        "    net.cuda()\n",
        "print(net)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "No existing network to delete\n",
            "\n",
            "Graph ConvNet: LeNet5\n",
            "nb of parameters= 2023818 \n",
            "\n",
            "Graph_ConvNet_LeNet5(\n",
            "  (cl1): Linear(in_features=25, out_features=32, bias=True)\n",
            "  (cl2): Linear(in_features=800, out_features=64, bias=True)\n",
            "  (fc1): Linear(in_features=3840, out_features=512, bias=True)\n",
            "  (fc2): Linear(in_features=512, out_features=10, bias=True)\n",
            ")\n"
          ],
          "name": "stdout"
        }
      ]
    },
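    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The printed parameter count can be checked by hand from the formula in `__init__` (weights plus biases for each layer). The check assumes $D = 960$, the padded number of nodes at layer 0, so that `FC1Fin` $= 64 \\times (960 // 16) = 64 \\times 60 = 3840$, which matches the `in_features` of `fc1` printed above:\n",
        "\n",
        "- CL1: $25 \\times 32 + 32 = 832$\n",
        "- CL2: $25 \\times 32 \\times 64 + 64 = 51264$\n",
        "- FC1: $3840 \\times 512 + 512 = 1966592$\n",
        "- FC2: $512 \\times 10 + 10 = 5130$\n",
        "\n",
        "Total: $832 + 51264 + 1966592 + 5130 = 2023818$. Almost all of the parameters sit in the first fully connected layer."
      ]
    },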
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VAHvOqwFIM9T"
      },
      "source": [
        "Good time, to check your network is working..."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iUBExiK7IM9U",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ecf6f722-db5a-45d8-990b-705aadb43b77"
      },
      "source": [
        "train_x, train_y = train_data[:5,:], train_labels[:5]\n",
        "train_x =  torch.FloatTensor(train_x).type(dtypeFloat)\n",
        "train_y = train_y.astype(np.int64)\n",
        "train_y = torch.LongTensor(train_y).type(dtypeLong) \n",
        "            \n",
        "# Forward \n",
        "y = net(train_x, dropout_value, L, lmax)\n",
        "print(y.shape)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "torch.Size([5, 10])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "scrolled": true,
        "id": "6pqSKZucIM9U",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "6dbb9edd-7b30-4f9a-af78-00b819f91476"
      },
      "source": [
        "# Weights\n",
        "L_net = list(net.parameters())\n",
        "\n",
        "# learning parameters\n",
        "learning_rate = 0.05\n",
        "l2_regularization = 5e-4 \n",
        "batch_size = 100\n",
        "num_epochs = 3\n",
        "train_size = train_data.shape[0]\n",
        "nb_iter = int(num_epochs * train_size) // batch_size\n",
        "print('num_epochs=',num_epochs,', train_size=',train_size,', nb_iter=',nb_iter)\n",
        "\n",
        "# Optimizer\n",
        "global_lr = learning_rate\n",
        "global_step = 0\n",
        "decay = 0.95\n",
        "decay_steps = train_size\n",
        "lr = learning_rate\n",
        "optimizer = net.update(lr) \n",
        "\n",
        "# loop over epochs\n",
        "indices = collections.deque()\n",
        "for epoch in range(num_epochs):  # loop over the dataset multiple times\n",
        "\n",
        "    # reshuffle \n",
        "    indices.extend(np.random.permutation(train_size)) # rand permutation\n",
        "    \n",
        "    # reset time\n",
        "    t_start = time.time()\n",
        "    \n",
        "    # extract batches\n",
        "    running_loss = 0.0\n",
        "    running_accuray = 0\n",
        "    running_total = 0\n",
        "    while len(indices) >= batch_size:\n",
        "        \n",
        "        # extract batches\n",
        "        batch_idx = [indices.popleft() for i in range(batch_size)]\n",
        "        train_x, train_y = train_data[batch_idx,:], train_labels[batch_idx]\n",
        "        train_x =  torch.FloatTensor(train_x).type(dtypeFloat)\n",
        "        train_y = train_y.astype(np.int64)\n",
        "        train_y = torch.LongTensor(train_y).type(dtypeLong) \n",
        "            \n",
        "        # Forward \n",
        "        y = net(train_x, dropout_value, L, lmax)\n",
        "        loss = net.loss(y,train_y,l2_regularization) \n",
        "        loss_train = loss.detach().item()\n",
        "        # Accuracy\n",
        "        acc_train = net.evaluation(y,train_y.data)\n",
        "        # backward\n",
        "        loss.backward()\n",
        "        # Update \n",
        "        global_step += batch_size # to update learning rate\n",
        "        optimizer.step()\n",
        "        optimizer.zero_grad()\n",
        "        # loss, accuracy\n",
        "        running_loss += loss_train\n",
        "        running_accuray += acc_train\n",
        "        running_total += 1\n",
        "        # print        \n",
        "        if not running_total%100: # print every x mini-batches\n",
        "            print('epoch= %d, i= %4d, loss(batch)= %.4f, accuray(batch)= %.2f' % (epoch+1, running_total, loss_train, acc_train))\n",
        "          \n",
        "    # print \n",
        "    t_stop = time.time() - t_start\n",
        "    print('epoch= %d, loss(train)= %.3f, accuracy(train)= %.3f, time= %.3f, lr= %.5f' % \n",
        "          (epoch+1, running_loss/running_total, running_accuray/running_total, t_stop, lr))\n",
        " \n",
        "    # update learning rate \n",
        "    lr = global_lr * pow( decay , float(global_step// decay_steps) )\n",
        "    optimizer = net.update_learning_rate(optimizer, lr)\n",
        "    \n",
        "    \n",
        "    # Test set\n",
        "    with torch.no_grad():\n",
        "        running_accuray_test = 0\n",
        "        running_total_test = 0\n",
        "        indices_test = collections.deque()\n",
        "        indices_test.extend(range(test_data.shape[0]))\n",
        "        t_start_test = time.time()\n",
        "        while len(indices_test) >= batch_size:\n",
        "            batch_idx_test = [indices_test.popleft() for i in range(batch_size)]\n",
        "            test_x, test_y = test_data[batch_idx_test,:], test_labels[batch_idx_test]\n",
        "            test_x = torch.FloatTensor(test_x).type(dtypeFloat)\n",
        "            y = net(test_x, 0.0, L, lmax) \n",
        "            test_y = test_y.astype(np.int64)\n",
        "            test_y = torch.LongTensor(test_y).type(dtypeLong)\n",
        "            acc_test = net.evaluation(y,test_y.data)\n",
        "            running_accuray_test += acc_test\n",
        "            running_total_test += 1\n",
        "        t_stop_test = time.time() - t_start_test\n",
        "        print('  accuracy(test) = %.3f %%, time= %.3f' % (running_accuray_test / running_total_test, t_stop_test))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "num_epochs= 3 , train_size= 60000 , nb_iter= 1800\n",
            "epoch= 1, i=  100, loss(batch)= 0.7743, accuray(batch)= 78.00\n",
            "epoch= 1, i=  200, loss(batch)= 0.4548, accuray(batch)= 87.00\n",
            "epoch= 1, i=  300, loss(batch)= 0.3556, accuray(batch)= 92.00\n",
            "epoch= 1, i=  400, loss(batch)= 0.2642, accuray(batch)= 96.00\n",
            "epoch= 1, i=  500, loss(batch)= 0.3511, accuray(batch)= 96.00\n",
            "epoch= 1, i=  600, loss(batch)= 0.4196, accuray(batch)= 94.00\n",
            "epoch= 1, loss(train)= 0.567, accuracy(train)= 86.562, time= 79.040, lr= 0.05000\n",
            "  accuracy(test) = 97.050 %, time= 3.137\n",
            "epoch= 2, i=  100, loss(batch)= 0.1950, accuray(batch)= 98.00\n",
            "epoch= 2, i=  200, loss(batch)= 0.3069, accuray(batch)= 93.00\n",
            "epoch= 2, i=  300, loss(batch)= 0.2529, accuray(batch)= 98.00\n",
            "epoch= 2, i=  400, loss(batch)= 0.4034, accuray(batch)= 93.00\n",
            "epoch= 2, i=  500, loss(batch)= 0.1715, accuray(batch)= 99.00\n",
            "epoch= 2, i=  600, loss(batch)= 0.2504, accuray(batch)= 94.00\n",
            "epoch= 2, loss(train)= 0.254, accuracy(train)= 96.468, time= 79.270, lr= 0.04750\n",
            "  accuracy(test) = 97.970 %, time= 3.147\n",
            "epoch= 3, i=  100, loss(batch)= 0.2337, accuray(batch)= 97.00\n",
            "epoch= 3, i=  200, loss(batch)= 0.1680, accuray(batch)= 99.00\n",
            "epoch= 3, i=  300, loss(batch)= 0.2316, accuray(batch)= 98.00\n",
            "epoch= 3, i=  400, loss(batch)= 0.2472, accuray(batch)= 95.00\n",
            "epoch= 3, i=  500, loss(batch)= 0.3163, accuray(batch)= 92.00\n",
            "epoch= 3, i=  600, loss(batch)= 0.1318, accuray(batch)= 100.00\n",
            "epoch= 3, loss(train)= 0.210, accuracy(train)= 97.343, time= 79.303, lr= 0.04512\n",
            "  accuracy(test) = 98.050 %, time= 3.149\n"
          ],
          "name": "stdout"
        }
      ]
    },
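    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The learning-rate update at the end of each epoch can be unrolled by hand. This is a sketch that assumes, as in the run above, that `train_size` is a multiple of `batch_size`, so that `global_step` equals $(e-1)N$ at the start of epoch $e$. The symbols $\\eta_0$ (for `learning_rate`), $\\eta_e$ (for `lr` during epoch $e$), $\\gamma$ (for `decay`), $s_e$ (for `global_step`) and $N$ (for `train_size` and `decay_steps`) are notation introduced here, not names from the code.\n",
        "\n",
        "$$\\eta_e = \\eta_0 \\, \\gamma^{\\lfloor s_e / N \\rfloor} = \\eta_0 \\, \\gamma^{e-1}, \\qquad s_e = (e-1) N$$\n",
        "\n",
        "With $\\eta_0 = 0.05$ and $\\gamma = 0.95$ this gives $\\eta_1 = 0.05$, $\\eta_2 = 0.0475$ and $\\eta_3 = 0.045125$, which agrees with the `lr` values printed at the end of each epoch above."
      ]
    },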
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WcMxnNIiXLyc"
      },
      "source": [
        "### <font color='red'>Question 3: In this code, each convolutional layer has a parameter K. What does it represent? What are the consequences of choosing a higher or lower value of K? </font> "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g1sYCHHfXxEX"
      },
      "source": [
        "$K$ is the number of Chebyshev polynomials used in the expansion of the kernel learned by the convolutional layer. In the paper's formulation the expansion uses $T_0, \\dots, T_{K-1}$ and the Chebyshev polynomial $T_k$ has degree $k$, so the learned filter is a polynomial of degree $K-1$ in the Laplacian, with $K$ coefficients.\n",
        "\n",
        "Therefore, the higher $K$ is, the more expressive the kernel can be. For instance, in the limit $K \\rightarrow \\infty$, we could possibly retrieve Gaussian filters. Moreover, since $\\mathrm{dist}_{\\mathcal G}(i,j) > K \\implies (L^K)_{i,j} = 0$, a polynomial of degree $K$ in $L$ only mixes nodes that are at most $K$ hops apart. A higher $K$ therefore enlarges the receptive field and lets the layer learn from nodes that are farther away from each other.\n",
        "\n",
        "The disadvantage of too large a $K$ is the computational cost, since each additional polynomial requires one more multiplication by $L$. It may also encourage overfitting, because the number of parameters grows with $K$, and it may push the network towards poor local minima. "
      ]
    },
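    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the role of $K$ concrete, the filter can be written out as in Section 2.1 of the paper. This is only a sketch in the paper's notation: $\\theta_k$ (learned coefficients), $\\tilde L$ (rescaled Laplacian), $x$ (input signal) and $\\bar x_k$ are symbols taken from the paper, not names used in this notebook's code.\n",
        "\n",
        "$$y = \\sum_{k=0}^{K-1} \\theta_k \\, T_k(\\tilde L) \\, x = \\sum_{k=0}^{K-1} \\theta_k \\, \\bar x_k, \\qquad \\bar x_0 = x, \\quad \\bar x_1 = \\tilde L x, \\quad \\bar x_k = 2 \\tilde L \\bar x_{k-1} - \\bar x_{k-2}$$\n",
        "\n",
        "Each $\\bar x_k$ costs one sparse multiplication by $\\tilde L$, so one filter costs $O(K |\\mathcal E|)$ operations and has $K$ coefficients: both the computation and the number of parameters grow linearly with $K$."
      ]
    },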
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QqHiWQ8iIM9U"
      },
      "source": [
        "### <font color='red'>Question 4: Is it necessary to rescale the Laplacian (in the function `rescale_L`)? Try to remove it and explain what happens. </font> \n",
        "\n",
        "Hint: See Section 2.1 of [the paper](https://arxiv.org/pdf/1606.09375.pdf).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PsyW8kfGJhQb"
      },
      "source": [
        "Without the rescaling, the loss becomes `nan` and the network cannot be trained.\n",
        "\n",
        "Recall that the Chebyshev polynomials form an orthogonal basis of $L^2([-1,1], dy/\\sqrt{1-y^2})$, so the argument of the filter should lie in $[-1,1]$. If the Laplacian is not rescaled, its eigenvalues (the diagonal entries of $\\Lambda$) lie in $[0, \\lambda_\\max ]$ and can therefore be larger than $1$. Outside $[-1,1]$ the magnitude of $T_k$ grows rapidly with $k$, so the repeated multiplications by $L$ in the Chebyshev recurrence can overflow, which is exactly what happens here with the `nan` value.\n",
        "\n",
        "The rescaling $L = 2L/\\lambda_\\max - I$ is therefore necessary to ensure that the eigenvalues lie in $[-1,1]$. The code used above reduces this interval to $[-1/4, 1/4]$, probably so that the network captures variations better by staying away from values close to $\\pm 1$."
      ]
    },
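    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A small worked example shows how quickly the polynomials blow up outside $[-1,1]$. The value $x = 3$ is chosen purely for illustration and is not the actual $\\lambda_\\max$ of the grid graph. Using $T_0(x) = 1$, $T_1(x) = x$ and $T_k(x) = 2x \\, T_{k-1}(x) - T_{k-2}(x)$:\n",
        "\n",
        "$$T_2(3) = 2 \\cdot 3 \\cdot 3 - 1 = 17, \\qquad T_3(3) = 2 \\cdot 3 \\cdot 17 - 3 = 99, \\qquad T_4(3) = 2 \\cdot 3 \\cdot 99 - 17 = 577$$\n",
        "\n",
        "whereas $|T_k(x)| = |\\cos(k \\arccos x)| \\le 1$ for every $k$ when $x \\in [-1,1]$. The affine map $\\lambda \\mapsto 2\\lambda/\\lambda_\\max - 1$ sends $[0, \\lambda_\\max]$ onto $[-1,1]$, which is what the rescaling does to each eigenvalue."
      ]
    },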
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Li12nmcbmhth"
      },
      "source": [
        "### <font color='red'>Question 5: Is it possible to modify the Laplacian to avoid the rescaling step?</font> \n",
        "\n",
        "Hint: Think about the eigenvalues of the Laplacian and how to normalize them."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fUEqQgHJmjUk"
      },
      "source": [
        "Instead of the unnormalized Laplacian, one can use the symmetric normalized Laplacian $L = I - D^{-1/2}WD^{-1/2}$. Its eigenvalues lie in $[0,2]$, so the rescaling can be replaced by the simple shift $L = L - I$, which maps them to $[-1,1]$.\n",
        "\n",
        "Using this, we obtain $97.020\\%$ accuracy after 3 epochs.\n",
        "\n",
        "To obtain an accuracy similar to the one obtained with the rescaled unnormalized Laplacian, i.e. $98.050\\%$, we can replace $L$ by $(L-I)/4$, which reduces the interval to $[-1/4, 1/4]$ (as mentioned in Question 4)."
      ]
    },
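    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The bound on the eigenvalues can be derived from the Rayleigh quotient. This is a sketch that assumes symmetric non-negative weights $w_{ij}$ and no isolated node (degrees $d_i = \\sum_j w_{ij} > 0$); $x$ denotes an arbitrary signal on the nodes.\n",
        "\n",
        "$$x^\\top \\left( I - D^{-1/2} W D^{-1/2} \\right) x = \\frac{1}{2} \\sum_{i,j} w_{ij} \\left( \\frac{x_i}{\\sqrt{d_i}} - \\frac{x_j}{\\sqrt{d_j}} \\right)^2 \\ge 0$$\n",
        "\n",
        "Since $(a-b)^2 \\le 2a^2 + 2b^2$, the same quantity is bounded above:\n",
        "\n",
        "$$\\frac{1}{2} \\sum_{i,j} w_{ij} \\left( \\frac{x_i}{\\sqrt{d_i}} - \\frac{x_j}{\\sqrt{d_j}} \\right)^2 \\le \\sum_{i,j} w_{ij} \\left( \\frac{x_i^2}{d_i} + \\frac{x_j^2}{d_j} \\right) = 2 \\lVert x \\rVert^2$$\n",
        "\n",
        "Hence every eigenvalue $\\lambda$ of the normalized Laplacian satisfies $0 \\le \\lambda \\le 2$, and subtracting $I$ shifts the spectrum into $[-1,1]$."
      ]
    },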
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cL4bpvypcGBx"
      },
      "source": [
        "### <font color='red'>Question 6: Is a GCN like the one presented in video 2 an MGNN? </font> \n",
        "* (A) Yes for K=1\n",
        "* (B) Yes for any value of K\n",
        "* (C) No\n",
        "\n",
        "Explain your answer.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1ktXFqpdgndJ"
      },
      "source": [
        "**(C) No**\n",
        "\n",
        "First, recall that an MGNN takes as input a discrete graph with features $h^0$ on the nodes, and is defined inductively as: \n",
        "$$h_i^{l+1} = f\\left( h_i^l, [h_j^l]_{j\\sim i} \\right) $$\n",
        "That is, the features of a node at layer $l+1$ depend only on its own features and on those of its neighbors at layer $l$. \n",
        "\n",
        "Therefore, for $K=1$, an MGNN layer and a graph convolutional layer behave in the same way. Indeed, for $K=1$ the kernels are $1$-localized, meaning that they only consider the direct neighbors of each node, which is precisely what the MGNN does. \n",
        "\n",
        "However, one key component of the GCN that is absent from the MGNN presented in the [course](https://dataflowr.github.io/slides/deep_graph_3.html) is the combination of pooling and graph-coarsening steps, which reduce the size of the graph as we go through the architecture. "
      ]
    },
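    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The locality claim can be checked directly on a single filter. This is a sketch that assumes a filter keeping only the terms $T_0$ and $T_1$ of the Chebyshev expansion (how this maps to the value of `K` in the code depends on whether $K$ counts the terms or the degree); $\\theta_0$, $\\theta_1$, $\\tilde L$ and $x$ follow the paper's notation.\n",
        "\n",
        "$$y_i = \\theta_0 x_i + \\theta_1 (\\tilde L x)_i = \\left( \\theta_0 + \\theta_1 \\tilde L_{ii} \\right) x_i + \\theta_1 \\sum_{j \\sim i} \\tilde L_{ij} \\, x_j$$\n",
        "\n",
        "The second equality holds because $\\tilde L_{ij} = 0$ whenever $j \\neq i$ is not a neighbor of $i$. The output at node $i$ therefore has the form $f\\left( x_i, [x_j]_{j \\sim i} \\right)$ of an MGNN layer, whereas terms $T_k(\\tilde L)$ with $k \\ge 2$ involve nodes up to $k$ hops away."
      ]
    },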
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BjwfS7bHSTpd"
      },
      "source": [
        "### <font color='red'> Question 7: In which cases do you expect: </font> \n",
        "* a Graph CNN to work better than a CNN?\n",
        "* a CNN to work better than a Graph CNN?\n",
        "\n",
        "For the MNIST classification problem, is there an advantage in using a Graph CNN instead of a CNN? Explain. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "62sZdTlhTb9s"
      },
      "source": [
        "- We can expect a Graph CNN to work better on problems where the data is naturally represented as graphs (molecules, proteins, genes, social networks, etc.). For such data there is usually no direct way to apply CNNs, or no way to apply them without discarding a large part of the information carried by the structure of the graph. (Some works have nevertheless tried to do it, such as https://arxiv.org/pdf/1708.02218.pdf).\n",
        "\n",
        "- On the other hand, we can expect a CNN to work better on 2D and 3D images, where the information is spatial and arranged on a regular grid. CNNs shine on recognition and classification tasks precisely because local Euclidean neighborhoods of pixels contain enough information to determine the class of an object. CNNs were originally designed for this type of data and this type of task. \n",
        "\n",
        "For the MNIST classification problem, there seems to be no real advantage in using a Graph CNN. The original paper argues that one cause is the isotropic (direction-invariant) nature of the GCNN filters. This is not an advantage on MNIST, since the digits are already correctly oriented. However, for tasks where objects can appear in different orientations, a Graph CNN should, by nature, be able to learn rotation-invariant features. Finally, even though the GCNN brings no improvement, its results are very close to those obtained with a CNN, meaning that it works almost as well on images. This lets the authors check that their method is consistent with the CNNs it is inspired by. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "l3UqvPQ74rys"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}