{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implement a Neural Network\n",
    "\n",
    "This notebook contains explanations and test code to help you develop a neural network by implementing the forward pass and the backpropagation algorithm in `models/neural_net.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from models.neural_net import NeuralNetwork\n",
    "\n",
    "%matplotlib inline\n",
    "plt.rcParams['figure.figsize'] = (10.0, 8.0)  # set default size of plots\n",
    "\n",
    "# For auto-reloading external modules\n",
    "# See http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "def rel_error(x, y):\n",
    "    \"\"\"Returns relative error\"\"\"\n",
    "    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will implement your network in the `NeuralNetwork` class in `models/neural_net.py`. The network parameters are stored in the instance variable `self.params`, a dictionary whose keys are parameter names (strings) and whose values are NumPy arrays.\n",
    "\n",
    "The cell below initializes a toy dataset and a corresponding model, which let you check your forward and backward passes with a numeric gradient check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a small net and some toy data to check your implementations.\n",
    "# Note that we set the random seed for repeatable experiments.\n",
    "\n",
    "input_size = 4\n",
    "hidden_size = 5\n",
    "num_classes = 3\n",
    "num_inputs = 5\n",
    "\n",
    "\n",
    "def init_toy_model(num_layers):\n",
    "    np.random.seed(0)\n",
    "    hidden_sizes = [hidden_size] * (num_layers - 1)\n",
    "    return NeuralNetwork(input_size, hidden_sizes, num_classes, num_layers)\n",
    "\n",
    "def init_toy_data():\n",
    "    np.random.seed(0)\n",
    "    X = 10 * np.random.randn(num_inputs, input_size)\n",
    "#    y = np.array([0, 1, 2, 2, 1]) # (C,)\n",
    "    y = np.random.randint(num_classes, size=num_inputs)\n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# net = init_toy_model(2)\n",
    "# X, y = init_toy_data()\n",
    "# print(X)\n",
    "# scores=net.forward(X)\n",
    "# print(scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implement forward and backward pass\n",
    "\n",
    "First, implement the forward pass of your neural network along with the loss calculation. The forward pass should be implemented in the `forward` function. You can use helper functions such as `linear`, `relu`, and `softmax` to organize your code.\n",
    "\n",
    "Next, implement the backward pass using the backpropagation algorithm. Backpropagation computes the gradient of the loss with respect to the model parameters `W1`, `b1`, and so on. Use a softmax function with cross-entropy loss for the loss calculation. Fill in the code blocks in `NeuralNetwork.backward`."
   ]
  },
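  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reference for the backward pass, the softmax and cross-entropy gradients can be written out as below. This is a sketch under assumed notation, not necessarily the convention used in `models/neural_net.py`: $z \\in \\mathbb{R}^{N \\times C}$ denotes the output scores for $N$ inputs and $C$ classes, $y_i$ is the label of input $i$, the loss is averaged over the batch, and regularization is left out.\n",
    "\n",
    "$$p_{ik} = \\frac{e^{z_{ik}}}{\\sum_{j=1}^{C} e^{z_{ij}}}, \\qquad L = -\\frac{1}{N} \\sum_{i=1}^{N} \\log p_{i y_i}, \\qquad \\frac{\\partial L}{\\partial z_{ik}} = \\frac{1}{N} \\left( p_{ik} - \\mathbb{1}[k = y_i] \\right)$$\n",
    "\n",
    "For a linear layer $z = hW + b$ whose input $h = \\max(0, a)$ is the output of a ReLU, the chain rule then gives\n",
    "\n",
    "$$\\frac{\\partial L}{\\partial W} = h^{\\top} \\frac{\\partial L}{\\partial z}, \\qquad \\frac{\\partial L}{\\partial b_k} = \\sum_{i=1}^{N} \\frac{\\partial L}{\\partial z_{ik}}, \\qquad \\frac{\\partial L}{\\partial h} = \\frac{\\partial L}{\\partial z} W^{\\top}, \\qquad \\frac{\\partial L}{\\partial a} = \\frac{\\partial L}{\\partial h} \\odot \\mathbb{1}[a > 0]$$\n",
    "\n",
    "Subtracting $\\max_j z_{ij}$ from each row of $z$ before exponentiating leaves $p$ unchanged and avoids overflow."
   ]
  },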
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the network\n",
    "To train the network we will use stochastic gradient descent (SGD), following the same training procedure you used for the SVM and Softmax classifiers.\n",
    "\n",
    "Once you have implemented SGD, run the code below to train a two-layer network on toy data. You should achieve a training loss below 0.2 using a two-layer network with ReLU activation."
   ]
  },
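  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to write the SGD step described above, as a sketch with assumed symbols: $\\theta$ stands for any parameter array in `self.params`, $\\eta$ for `learning_rate`, $B$ for a minibatch of `batch_size` examples, and $L_B$ for the average loss on that minibatch.\n",
    "\n",
    "$$\\theta \\leftarrow \\theta - \\eta \\, \\nabla_{\\theta} L_B(\\theta)$$\n",
    "\n",
    "If an L2 penalty $\\frac{\\lambda}{2} \\lVert W \\rVert^2$ is added to the loss (an assumption about how `regularization` enters; your implementation may define the penalty differently), the step for a weight matrix $W$ gains the term $\\lambda W$:\n",
    "\n",
    "$$W \\leftarrow W - \\eta \\left( \\nabla_{W} L_B + \\lambda W \\right)$$"
   ]
  },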
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "input_size = 4\n",
    "hidden_size = 5\n",
    "num_classes = 3\n",
    "num_inputs = 5\n",
    "\n",
    "epochs = 500\n",
    "batch_size = 1  # defined but not passed to the network in this cell\n",
    "learning_rate = 1e-2\n",
    "learning_rate_decay = 0.95  # defined but not used in this cell\n",
    "regularization = 5e-6\n",
    "\n",
    "# Initialize a new neural network model and the toy data\n",
    "net = init_toy_model(3)\n",
    "X, y = init_toy_data()\n",
    "\n",
    "# Variables to store performance for each epoch\n",
    "train_loss = np.zeros(epochs)\n",
    "train_accuracy = np.zeros(epochs)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # Run the forward pass of the model to get a prediction and compute the accuracy\n",
    "    scores = net.forward(X)\n",
    "    y_pred = np.argmax(scores, axis=1)\n",
    "    train_accuracy[epoch] = np.mean(y == y_pred)\n",
    "\n",
    "    # Run the backward pass of the model to update the weights and compute the loss\n",
    "    train_loss[epoch], _ = net.backward(X, y, \"SGD\", learning_rate, regularization)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.71160664 1.19447489 0.84228145 0.78094819 0.75656102 0.73364339\n",
      " 0.71330759 0.69511702 0.67749643 0.66096199 0.64567992 0.63087822\n",
      " 0.61692832 0.60374883 0.59108207 0.57909891 0.56760112 0.55666277\n",
      " 0.54617023 0.53613973 0.52653433 0.5173223  0.50847589 0.4999706\n",
      " 0.49178469 0.48389866 0.47629497 0.46895771 0.46187239 0.45502573\n",
      " 0.45030083 0.44900411 0.43833874 0.43320147 0.43332431 0.42295032\n",
      " 0.41883799 0.4183316  0.40854689 0.40595836 0.40412592 0.39505082\n",
      " 0.39405146 0.39076544 0.38376585 0.38404575 0.37530196 0.37458112\n",
      " 0.37126373 0.36539509 0.36511662 0.35768095 0.35873265 0.35096842\n",
      " 0.35001708 0.34721579 0.34199963 0.34177059 0.33520272 0.33610852\n",
      " 0.32942698 0.32837909 0.32591951 0.32123385 0.32107859 0.31526243\n",
      " 0.31462537 0.31190532 0.30794384 0.30744494 0.3022725  0.30177031\n",
      " 0.29904386 0.29563123 0.29496148 0.29035119 0.29063026 0.28599801\n",
      " 0.28482759 0.28291855 0.27924855 0.27911989 0.27495904 0.27380246\n",
      " 0.27197447 0.26847148 0.26833271 0.26458264 0.26310978 0.26161926\n",
      " 0.25833508 0.25724321 0.25540794 0.25227064 0.25139375 0.24939934\n",
      " 0.24642576 0.24602229 0.24282877 0.24117025 0.23990825 0.23701928\n",
      " 0.23555711 0.23410722 0.23136592 0.2298314  0.22843801 0.22575736\n",
      " 0.22417577 0.22268554 0.21758607 0.21280212 0.20868023 0.20503003\n",
      " 0.20022854 0.19639372 0.19311524 0.18866058 0.18775713 0.18491384\n",
      " 0.18342773 0.18187414 0.18066043 0.17717459 0.17735266 0.17314705\n",
      " 0.17257691 0.16991321 0.16917376 0.16751736 0.16539878 0.16286308\n",
      " 0.16258098 0.16059502 0.15843273 0.15778541 0.15485309 0.1538432\n",
      " 0.15166707 0.15082267 0.14909069 0.14860329 0.14551451 0.14598558\n",
      " 0.14242643 0.14220086 0.14010415 0.13942705 0.13812383 0.13615243\n",
      " 0.13573227 0.13320311 0.13327858 0.13049375 0.12999026 0.12795275\n",
      " 0.1278469  0.12619924 0.12492655 0.12408978 0.12233023 0.1211392\n",
      " 0.11985643 0.11895319 0.11791898 0.11717066 0.11530556 0.11520673\n",
      " 0.1129439  0.11252972 0.11096721 0.11026934 0.10913037 0.10843899\n",
      " 0.10749774 0.1061061  0.10576177 0.10394068 0.10395165 0.10194578\n",
      " 0.10155984 0.10025005 0.09954166 0.09864124 0.09791315 0.09718039\n",
      " 0.09587541 0.09565518 0.09396719 0.09406441 0.09237628 0.09182299\n",
      " 0.09090766 0.09002835 0.08946549 0.08829605 0.08799749 0.08692431\n",
      " 0.08676674 0.08537622 0.08484571 0.08437521 0.08323046 0.08300994\n",
      " 0.08168249 0.0817329  0.0801343  0.08045857 0.07888134 0.07861555\n",
      " 0.07783598 0.07725441 0.07677213 0.0757353  0.07567445 0.07439185\n",
      " 0.07415899 0.07334902 0.0727537  0.07225594 0.07163538 0.07131159\n",
      " 0.07027547 0.07027079 0.06919657 0.06880791 0.06825643 0.0675477\n",
      " 0.06724823 0.06637532 0.06630784 0.06543366 0.06518618 0.06468958\n",
      " 0.0639827  0.06380945 0.06286491 0.0628989  0.06184429 0.06176501\n",
      " 0.06112313 0.06074493 0.06035121 0.05965843 0.05958164 0.05864816\n",
      " 0.05855438 0.05793004 0.05764477 0.05725765 0.05661996 0.05651629\n",
      " 0.055691   0.05560411 0.05502279 0.05474995 0.05436799 0.0538025\n",
      " 0.05369292 0.05294868 0.05284004 0.05233093 0.05206347 0.05170763\n",
      " 0.05118333 0.05109066 0.05039097 0.05029263 0.04980968 0.04955047\n",
      " 0.04925585 0.04874486 0.04868978 0.0480162  0.04792206 0.04749238\n",
      " 0.04721412 0.04697538 0.04646505 0.04644904 0.04581782 0.04578462\n",
      " 0.04535969 0.04501814 0.04486669 0.04434837 0.04431524 0.04386036\n",
      " 0.04367675 0.0434311  0.04312391 0.04300414 0.0424858  0.04256101\n",
      " 0.04202048 0.04195246 0.04167967 0.04129284 0.04121257 0.04080914\n",
      " 0.04069414 0.04048859 0.0401542  0.04000817 0.03964914 0.03964369\n",
      " 0.0391734  0.03906593 0.03883351 0.03855278 0.03851353 0.03804701\n",
      " 0.03808865 0.03775564 0.03751987 0.03735921 0.03702998 0.03701761\n",
      " 0.03665555 0.03655118 0.03631673 0.03610951 0.03600849 0.0356354\n",
      " 0.03558937 0.03534576 0.03515698 0.03504498 0.03469103 0.0346779\n",
      " 0.03436212 0.03426892 0.0341087  0.03383241 0.03375238 0.03345054\n",
      " 0.03338849 0.03320444 0.03297961 0.03287339 0.03260039 0.0326162\n",
      " 0.03227467 0.03217288 0.03202543 0.03178983 0.03178685 0.03145104\n",
      " 0.03144053 0.03122914 0.03101787 0.03098392 0.03067071 0.03068492\n",
      " 0.03045522 0.03027999 0.03017619 0.02994562 0.0299463  0.02966546\n",
      " 0.02958226 0.02944882 0.02924317 0.02923175 0.02894586 0.02890569\n",
      " 0.02873878 0.0285691  0.02852878 0.02826264 0.0282631  0.02807509\n",
      " 0.02791323 0.02783175 0.02762218 0.02762726 0.02738637 0.02729672\n",
      " 0.02718792 0.02700747 0.02699189 0.02675373 0.02671837 0.02657672\n",
      " 0.02639603 0.02635562 0.02615844 0.02611881 0.02598976 0.02582067\n",
      " 0.02577438 0.02558349 0.02553754 0.02539603 0.02529313 0.02521509\n",
      " 0.02501787 0.02499234 0.024848   0.02472788 0.02468915 0.02447551\n",
      " 0.02448571 0.02432407 0.02420642 0.02413256 0.02397847 0.02397496\n",
      " 0.02378482 0.02371359 0.02362536 0.02348076 0.02347308 0.02328252\n",
      " 0.02323491 0.02313427 0.02299925 0.02298831 0.02280785 0.02276762\n",
      " 0.02267545 0.02253203 0.02250054 0.02234687 0.02230793 0.02219853\n",
      " 0.02209658 0.02206339 0.02189541 0.02186953 0.02175927 0.02165034\n",
      " 0.0216277  0.02146417 0.02145382 0.02133932 0.02122918 0.02118919\n",
      " 0.02104852 0.02102386 0.02091383 0.02082734 0.0207986  0.02063825\n",
      " 0.02062764 0.02052058 0.02042788 0.02040192 0.02024949 0.02025056\n",
      " 0.02013704 0.02004029 0.02000242 0.01987058 0.01985973 0.01975822\n",
      " 0.01968476 0.01963996 0.01950221 0.01950536 0.01939783 0.01931618\n",
      " 0.01928388 0.01915085 0.01913957 0.01904754 0.01896039 0.01892856\n",
      " 0.01880364 0.0188177  0.01868898 0.01863843 0.01859184 0.01846655\n",
      " 0.01846752 0.01837243 0.01829918 0.01825771 0.01814767 0.01815478\n",
      " 0.01803396 0.01799435 0.0179434  0.01782748 0.01782941 0.01772537\n",
      " 0.01767203 0.01762634 0.01753253 0.01752936 0.01741993 0.01738116\n",
      " 0.01733158 0.01722837 0.01723392 0.01712839 0.01708069 0.01703938\n",
      " 0.01693879 0.01693194]\n",
      "[0.4 0.4 0.4 0.6 0.6 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8\n",
      " 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8\n",
      " 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8\n",
      " 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8\n",
      " 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 1.  0.8 1.  0.8 1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1. ]\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x576 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the training loss and the training accuracy per epoch\n",
    "print(train_loss)\n",
    "print(train_accuracy)\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.plot(train_loss)\n",
    "plt.title('Loss history')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "\n",
    "plt.subplot(2, 1, 2)\n",
    "plt.plot(train_accuracy)\n",
    "plt.title('Classification accuracy history')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Classification accuracy')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "# Center the data from the previous cell. X is re-created by init_toy_data() below,\n",
    "# so this centering does not affect the Adam run.\n",
    "mean_image = np.mean(X, axis=0)\n",
    "X -= mean_image\n",
    "\n",
    "\n",
    "# Helpers that apply an elementwise operation to every array in a parameter dict\n",
    "def set_zero(dictio):\n",
    "    # Same keys, with each array replaced by zeros of the same shape\n",
    "    return {wx: np.zeros(dictio[wx].shape) for wx in dictio}\n",
    "\n",
    "def dict_x_ele(dict1, ele):\n",
    "    # Multiply every array by the scalar ele\n",
    "    return {wx: np.multiply(dict1[wx], ele) for wx in dict1}\n",
    "\n",
    "def dict_by_ele(dict1, ele):\n",
    "    # Divide every array by the scalar ele\n",
    "    return {wx: np.divide(dict1[wx], ele) for wx in dict1}\n",
    "\n",
    "def dict_sqr(dict1):\n",
    "    # Square every array elementwise\n",
    "    return {wx: np.square(dict1[wx]) for wx in dict1}\n",
    "\n",
    "def dict_sqrt(dict1):\n",
    "    # Take the elementwise square root of every array\n",
    "    return {wx: np.sqrt(dict1[wx]) for wx in dict1}\n",
    "\n",
    "def dict_add_ele(dict1, ele):\n",
    "    # Add the scalar ele to every array\n",
    "    return {wx: dict1[wx] + ele for wx in dict1}\n",
    "\n",
    "def dict_by_dict(dict1, dict2):\n",
    "    # Divide matching arrays elementwise\n",
    "    return {wx: dict1[wx] / dict2[wx] for wx in dict1}\n",
    "\n",
    "def dict_minus_dict(dict1, dict2):\n",
    "    # Subtract matching arrays elementwise\n",
    "    return {wx: dict1[wx] - dict2[wx] for wx in dict1}\n",
    "\n",
    "def dict_add_dict(dict1, dict2):\n",
    "    # Add matching arrays elementwise\n",
    "    return {wx: dict1[wx] + dict2[wx] for wx in dict1}\n",
    "\n",
    "\n",
    "# Hyperparameters\n",
    "epochs = 100\n",
    "batch_size = 1  # defined but not passed to the network in this cell\n",
    "learning_rate = 0.05\n",
    "learning_rate_decay = 0.95  # defined but not used in this cell\n",
    "regularization = 0.025\n",
    "betas = (0.9, 0.999)\n",
    "epsilon = 1e-8\n",
    "\n",
    "# Initialize a new neural network model and the toy data\n",
    "net = init_toy_model(2)\n",
    "X, y = init_toy_data()\n",
    "\n",
    "# Adam state: first (m) and second (v) moment estimates, one zero array per parameter\n",
    "w = copy.copy(net.params)\n",
    "m = set_zero(w)\n",
    "v = set_zero(w)\n",
    "w_prev = copy.deepcopy(net.params)\n",
    "v_prev = set_zero(w)\n",
    "m_prev = set_zero(w)\n",
    "\n",
    "# Variables to store performance for each epoch\n",
    "train_loss_adam = np.zeros(epochs)\n",
    "train_accuracy_adam = np.zeros(epochs)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # Run the forward pass of the model to get a prediction and compute the accuracy\n",
    "    scores = net.forward(X)\n",
    "    y_pred = np.argmax(scores, axis=1)\n",
    "    train_accuracy_adam[epoch] = np.mean(y == y_pred)\n",
    "\n",
    "    # Run the backward pass of the model to compute the loss and the gradients\n",
    "    loss, gradients = net.backward(X, y, learning_rate, regularization)\n",
    "    train_loss_adam[epoch] = loss\n",
    "\n",
    "    # First moment: m = beta1 * m_prev + (1 - beta1) * g\n",
    "    m = dict_add_dict(dict_x_ele(m_prev, betas[0]), dict_x_ele(gradients, 1 - betas[0]))\n",
    "    # Second moment: v = beta2 * v_prev + (1 - beta2) * g**2\n",
    "    v = dict_add_dict(dict_x_ele(v_prev, betas[1]), dict_x_ele(dict_sqr(gradients), 1 - betas[1]))\n",
    "\n",
    "    # Rescale the moments by the constant factors 1 / (1 - beta).\n",
    "    # Standard Adam divides by (1 - beta ** t) instead, where t is the step count.\n",
    "    m_new = dict_by_ele(m, 1 - betas[0])\n",
    "    v_new = dict_by_ele(v, 1 - betas[1])\n",
    "\n",
    "    # Parameter update: w = w_prev - learning_rate * m_new / (sqrt(v_new) + epsilon)\n",
    "    step = dict_by_dict(dict_x_ele(m_new, learning_rate), dict_add_ele(dict_sqrt(v_new), epsilon))\n",
    "    w = dict_minus_dict(w_prev, step)\n",
    "    net.params = copy.copy(w)\n",
    "\n",
    "    w_prev = copy.copy(w)\n",
    "    v_prev = copy.copy(v)\n",
    "    m_prev = copy.copy(m)"
   ]
  },
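  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison with the loop above, the commonly used form of the Adam update is sketched here. The symbols are assumptions mapped onto the notebook's names: $g_t$ is the gradient at step $t$ (counting from 1), $\\beta_1$ and $\\beta_2$ are the two entries of `betas`, $\\alpha$ is `learning_rate`, $\\epsilon$ is `epsilon`, and all operations are elementwise.\n",
    "\n",
    "$$m_t = \\beta_1 m_{t-1} + (1 - \\beta_1) g_t, \\qquad v_t = \\beta_2 v_{t-1} + (1 - \\beta_2) g_t^2$$\n",
    "\n",
    "$$\\hat{m}_t = \\frac{m_t}{1 - \\beta_1^t}, \\qquad \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t}$$\n",
    "\n",
    "$$\\theta_t = \\theta_{t-1} - \\alpha \\frac{\\hat{m}_t}{\\sqrt{\\hat{v}_t} + \\epsilon}$$\n",
    "\n",
    "The bias-correction denominators depend on the step $t$ and approach 1 as $t$ grows. Dividing by the constants $1 - \\beta_1$ and $1 - \\beta_2$, as the loop above does, matches this form only at $t = 1$."
   ]
  },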
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.26322338e+01 8.36293605e+00 4.80967297e+00 3.40966175e+00\n",
      " 3.80155087e+00 3.45677951e+00 2.80491091e+00 3.09474568e+00\n",
      " 2.78714955e+00 2.00547310e+00 1.38990564e+00 1.01073193e+00\n",
      " 7.25033433e-01 5.86671959e-01 1.01253767e+00 6.70870100e-01\n",
      " 2.04300719e-01 2.15608821e-01 2.88228335e-01 3.51809407e-01\n",
      " 3.82433019e-01 3.73283935e-01 3.26531080e-01 2.57288597e-01\n",
      " 1.92411810e-01 1.43350875e-01 1.06511763e-01 8.07162498e-02\n",
      " 6.32776410e-02 5.17429257e-02 4.46587300e-02 4.16690935e-02\n",
      " 4.30930418e-02 4.85284204e-02 5.45129069e-02 5.50725508e-02\n",
      " 4.80283644e-02 3.76379876e-02 2.87017281e-02 2.27005806e-02\n",
      " 1.90647967e-02 1.69024116e-02 1.55762420e-02 1.47115245e-02\n",
      " 1.40980915e-02 1.36203765e-02 1.32130713e-02 1.28385424e-02\n",
      " 1.24750536e-02 1.21104823e-02 1.17387970e-02 1.13579389e-02\n",
      " 1.09684470e-02 1.05725025e-02 1.01732372e-02 9.77422413e-03\n",
      " 9.37910757e-03 8.99134470e-03 8.61403648e-03 8.24983010e-03\n",
      " 7.90087617e-03 7.56882623e-03 7.25485770e-03 6.95971621e-03\n",
      " 6.68376683e-03 6.42704852e-03 6.18932730e-03 5.97014568e-03\n",
      " 5.76886681e-03 5.58471291e-03 5.41679802e-03 5.26445779e-03\n",
      " 5.12689165e-03 5.00252305e-03 4.89025735e-03 4.78899682e-03\n",
      " 4.69765368e-03 4.61516222e-03 4.54049021e-03 4.47264963e-03\n",
      " 4.41185262e-03 4.35633517e-03 4.30514075e-03 4.25752107e-03\n",
      " 4.21280870e-03 4.17042024e-03 4.12985721e-03 4.09070444e-03\n",
      " 4.05262644e-03 4.01536166e-03 3.97871531e-03 3.94255107e-03\n",
      " 3.90678211e-03 3.87136209e-03 3.83627628e-03 3.80153340e-03\n",
      " 3.76715827e-03 3.73318549e-03 3.69965416e-03 3.66660383e-03]\n",
      "[0.2 0.2 0.4 0.6 0.4 0.6 0.8 0.6 0.4 0.6 0.8 0.8 0.8 0.8 0.4 0.6 1.  1.\n",
      " 0.8 0.8 0.8 0.8 0.8 0.8 0.8 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.\n",
      " 1.  1.  1.  1.  1.  1.  1.  1.  1.  1. ]\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x576 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the training loss and the training accuracy per epoch for the Adam run\n",
    "print(train_loss_adam)\n",
    "print(train_accuracy_adam)\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.plot(train_loss_adam)\n",
    "plt.title('Loss history')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "\n",
    "plt.subplot(2, 1, 2)\n",
    "plt.plot(train_accuracy_adam)\n",
    "plt.title('Classification accuracy history')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Classification accuracy')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}