Recurrent-Neural-Networks / generation.ipynb
generation.ipynb
Raw
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tNSNsNcfz126"
   },
   "source": [
    "# Generating Text with an RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 6764,
     "status": "ok",
     "timestamp": 1606238124957,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "ryGrBgefz126",
    "outputId": "d6e8727a-c3db-4e13-d59a-e1a6f7d2eb9f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting Unidecode\n",
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/d0/42/d9edfed04228bacea2d824904cae367ee9efd05e6cce7ceaaedd0b0ad964/Unidecode-1.1.1-py2.py3-none-any.whl (238kB)\n",
      "\r",
      "\u001b[K     |█▍                              | 10kB 28.3MB/s eta 0:00:01\r",
      "\u001b[K     |██▊                             | 20kB 31.2MB/s eta 0:00:01\r",
      "\u001b[K     |████▏                           | 30kB 20.3MB/s eta 0:00:01\r",
      "\u001b[K     |█████▌                          | 40kB 24.0MB/s eta 0:00:01\r",
      "\u001b[K     |██████▉                         | 51kB 23.1MB/s eta 0:00:01\r",
      "\u001b[K     |████████▎                       | 61kB 25.8MB/s eta 0:00:01\r",
      "\u001b[K     |█████████▋                      | 71kB 17.2MB/s eta 0:00:01\r",
      "\u001b[K     |███████████                     | 81kB 17.7MB/s eta 0:00:01\r",
      "\u001b[K     |████████████▍                   | 92kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |█████████████▊                  | 102kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |███████████████▏                | 112kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |████████████████▌               | 122kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |█████████████████▉              | 133kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |███████████████████▎            | 143kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |████████████████████▋           | 153kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |██████████████████████          | 163kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |███████████████████████▍        | 174kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |████████████████████████▊       | 184kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |██████████████████████████▏     | 194kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |███████████████████████████▌    | 204kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |████████████████████████████▉   | 215kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |██████████████████████████████▎ | 225kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |███████████████████████████████▋| 235kB 17.1MB/s eta 0:00:01\r",
      "\u001b[K     |████████████████████████████████| 245kB 17.1MB/s \n",
      "\u001b[?25hInstalling collected packages: Unidecode\n",
      "Successfully installed Unidecode-1.1.1\n"
     ]
    }
   ],
   "source": [
    "!pip install Unidecode\n",
    "\n",
    "import unidecode\n",
    "import string\n",
    "import random\n",
    "import re\n",
    "import time\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 23974,
     "status": "ok",
     "timestamp": 1606238107555,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "s9CBbG-Lz8jO",
    "outputId": "810c3c23-ffca-4c82-89b9-33a01680077c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mounted at /content/gdrive\n"
     ]
    }
   ],
   "source": [
    "from google.colab import drive\n",
    "\n",
    "# Mount Google Drive at /content/gdrive so the project files stored there are reachable.\n",
    "drive.mount('/content/gdrive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 625,
     "status": "ok",
     "timestamp": 1606238115471,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "tWZZ0LtOz-q_",
    "outputId": "a4f42e28-a027-4c73-b358-b9c92e73f238"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/content/gdrive/My Drive/DL_stuff/assignment_4_part_2\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Use an absolute path: the relative \"gdrive/...\" path only resolves while the\n",
    "# working directory is /content, so re-running this cell would fail.\n",
    "PROJECT_DIR = \"/content/gdrive/My Drive/DL_stuff/assignment_4_part_2\"\n",
    "os.chdir(PROJECT_DIR)\n",
    "!pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 1134,
     "status": "ok",
     "timestamp": 1606238129397,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "UGF2OePpz127"
   },
   "outputs": [],
   "source": [
    "from rnn.model import RNN\n",
    "from rnn.helpers import time_since\n",
    "from rnn.generate import generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 284,
     "status": "ok",
     "timestamp": 1606238131352,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "ivnXVh-9z128",
    "outputId": "cac8379f-a300-412e-b92b-490b5408a91c",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.\n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4Yn2thRqz128"
   },
   "source": [
    "## Data Processing\n",
    "\n",
    "The file we are using is a plain text file. We turn any potential unicode characters into plain ASCII by using the `unidecode` package (which you can install via `pip` or `conda`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 917,
     "status": "ok",
     "timestamp": 1606238135558,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "DyoO0kR7z129",
    "outputId": "eb07f773-e8dc-4d96-c8fb-e7b729643070"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "file_len = 4573338\n",
      "train len:  4116004\n",
      "test len:  457334\n"
     ]
    }
   ],
   "source": [
    "# Vocabulary: every printable ASCII character; indices into this string are the model's tokens.\n",
    "all_characters = string.printable\n",
    "n_characters = len(all_characters)\n",
    "\n",
    "file_path = './shakespeare.txt'\n",
    "# Read with an explicit encoding and close the handle deterministically, then\n",
    "# transliterate any non-ASCII characters to plain ASCII.\n",
    "with open(file_path, encoding='utf-8') as f:\n",
    "    file = unidecode.unidecode(f.read())\n",
    "file_len = len(file)\n",
    "print('file_len =', file_len)\n",
    "\n",
    "# we will leave the last 1/10th of text as test\n",
    "split = int(0.9*file_len)\n",
    "train_text = file[:split]\n",
    "test_text = file[split:]\n",
    "assert len(train_text) + len(test_text) == file_len\n",
    "\n",
    "print('train len: ', len(train_text))\n",
    "print('test len: ', len(test_text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 296,
     "status": "ok",
     "timestamp": 1606238137359,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "zOh2NDhAz129",
    "outputId": "9ec9b027-0988-47fd-9e52-aa2477220488"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " sharpens the stomach. Come,\n",
      "Leonine, take her by the arm, walk with her.\n",
      "\n",
      "MARINA:\n",
      "No, I pray you;\n",
      "I'll not bereave you of your servant.\n",
      "\n",
      "DIONYZA:\n",
      "Come, come;\n",
      "I love the king your father, and yourself,\n"
     ]
    }
   ],
   "source": [
    "chunk_len = 200\n",
    "\n",
    "def random_chunk(text):\n",
    "    \"\"\"Return a random slice of `text` with exactly chunk_len + 1 characters.\n",
    "\n",
    "    The extra character lets the slice be split into an input (all but the last\n",
    "    character) and a target (all but the first), each chunk_len long.\n",
    "    \"\"\"\n",
    "    # random.randint includes its upper bound, so the last start that still leaves\n",
    "    # chunk_len + 1 characters is len(text) - chunk_len - 1; any larger start would\n",
    "    # run past the end of the text and return a slice that is too short.\n",
    "    start_index = random.randint(0, len(text) - chunk_len - 1)\n",
    "    end_index = start_index + chunk_len + 1\n",
    "    return text[start_index:end_index]\n",
    "\n",
    "print(random_chunk(train_text))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "N8_75JkOz129"
   },
   "source": [
    "### Input and Target data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kBNSF4goz12-"
   },
   "source": [
    "To make training samples out of the large string of text data, we will be splitting the text into chunks.\n",
    "\n",
    "Each chunk will be turned into a tensor, specifically a `LongTensor` (used for integer values), by looping through the characters of the string and looking up the index of each character in `all_characters`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "executionInfo": {
     "elapsed": 290,
     "status": "ok",
     "timestamp": 1606238146654,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "Zq-7g92gz12-"
   },
   "outputs": [],
   "source": [
    "# Turn string into list of longs\n",
    "def char_tensor(string):\n",
    "    \"\"\"Encode `string` as a 1-D LongTensor of indices into `all_characters`.\n",
    "\n",
    "    Raises ValueError (from str.index) if a character is not in `all_characters`.\n",
    "    \"\"\"\n",
    "    # requires_grad is deliberately not set: integer index tensors cannot carry\n",
    "    # gradients. Building from a list also avoids one Python-level tensor write per character.\n",
    "    return torch.tensor([all_characters.index(c) for c in string], dtype=torch.long)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ye0zfwoxz12-"
   },
   "source": [
    "The following function loads a batch of input and target tensors for training. Each sample comes from a random chunk of text. A sample input will consist of all characters *except the last*, while the target will contain all characters *following the first*. For example, if `random_chunk='abc'`, then `input='ab'` and `target='bc'`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "executionInfo": {
     "elapsed": 1004,
     "status": "ok",
     "timestamp": 1606238148933,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "HYf9fU5uz12_"
   },
   "outputs": [],
   "source": [
    "def load_random_batch(text, chunk_len, batch_size):\n",
    "    \"\"\"Sample `batch_size` random (input, target) character pairs from `text`.\n",
    "\n",
    "    Each sample is a random slice of chunk_len + 1 characters; the input is the\n",
    "    slice without its last character and the target is the slice shifted by one.\n",
    "\n",
    "    Returns:\n",
    "    - input_data, target: LongTensors of shape (batch_size, chunk_len) on `device`\n",
    "\n",
    "    Raises ValueError (from random.randint) if `text` has fewer than\n",
    "    chunk_len + 1 characters and batch_size > 0.\n",
    "    \"\"\"\n",
    "    # Fill the batch on the CPU and move it to the device once per tensor, instead\n",
    "    # of an implicit host-to-device copy for every row.\n",
    "    input_data = torch.zeros(batch_size, chunk_len, dtype=torch.long)\n",
    "    target = torch.zeros(batch_size, chunk_len, dtype=torch.long)\n",
    "    for i in range(batch_size):\n",
    "        start_index = random.randint(0, len(text) - chunk_len - 1)\n",
    "        chunk = text[start_index:start_index + chunk_len + 1]\n",
    "        input_data[i] = char_tensor(chunk[:-1])\n",
    "        target[i] = char_tensor(chunk[1:])\n",
    "    return input_data.to(device), target.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ueM_L0t9z12_"
   },
   "source": [
    "# Implement model\n",
    "\n",
    "Your RNN model will take as input the character for step $t-1$ and output a prediction for the next character at step $t$. The model should consist of three layers: an embedding layer (equivalent to a linear layer applied to a one-hot input) that encodes the input character into an embedded state, an RNN layer (which may itself have multiple layers) that operates on that embedded state and a hidden state, and a decoder layer that outputs a distribution of scores over the next character.\n",
    "\n",
    "\n",
    "You must implement your model in the `rnn/model.py` file. You should use an `nn.Embedding` object for the encoding layer, an RNN module such as `nn.RNN`, `nn.GRU`, or `nn.LSTM`, and an `nn.Linear` layer for the final decoding layer that outputs the predicted character scores.\n",
    "\n",
    "\n",
    "**TODO:** Implement the RNN model in `rnn/model.py`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7ghSMnj2z12_"
   },
   "source": [
    "# Evaluating\n",
    "\n",
    "To evaluate the network we will feed one character at a time, use the outputs of the network as a probability distribution for the next character, and repeat. To start generation we pass a priming string to start building up the hidden state, from which we then generate one character at a time.\n",
    "\n",
    "\n",
    "Note that in the `evaluate` function, every time a prediction is made the outputs are divided by the \"temperature\" argument. Higher temperature values make all characters more equally likely, giving more \"random\" outputs. With lower temperature values (less than 1), high-likelihood characters contribute more. A temperature near 0 outputs only the most likely characters.\n",
    "\n",
    "You may check different temperature values yourself, but we have provided a default which should work well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "executionInfo": {
     "elapsed": 213,
     "status": "ok",
     "timestamp": 1606238151708,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "mKSVxrP9z12_"
   },
   "outputs": [],
   "source": [
    "def evaluate(rnn, prime_str='A', predict_len=100, temperature=0.8):\n",
    "    \"\"\"Generate text by sampling one character at a time from `rnn`.\n",
    "\n",
    "    Inputs:\n",
    "    - rnn: model with init_hidden(batch_size, device=...) whose forward pass is\n",
    "      called as rnn(indices_of_shape_(1,), hidden) -> (scores, hidden)\n",
    "    - prime_str: non-empty string used to build up the hidden state\n",
    "    - predict_len: number of characters to generate after prime_str\n",
    "    - temperature: scores are divided by this before sampling; values below 1\n",
    "      favour the most likely characters, values above 1 flatten the distribution\n",
    "\n",
    "    Returns:\n",
    "    - prime_str followed by predict_len sampled characters\n",
    "    \"\"\"\n",
    "    if not prime_str:\n",
    "        raise ValueError('prime_str must contain at least one character')\n",
    "\n",
    "    # Sampling never backpropagates, so do not build an autograd graph through\n",
    "    # the whole generated sequence.\n",
    "    with torch.no_grad():\n",
    "        hidden = rnn.init_hidden(1, device=device)\n",
    "        prime_input = char_tensor(prime_str).to(device)\n",
    "        predicted = prime_str\n",
    "\n",
    "        # Use priming string to \"build up\" hidden state\n",
    "        for p in range(len(prime_str) - 1):\n",
    "            _, hidden = rnn(prime_input[p].unsqueeze(0), hidden)\n",
    "        inp = prime_input[-1]\n",
    "\n",
    "        for _ in range(predict_len):\n",
    "            # inp is always a 0-d index, so every step feeds the model shape (1,).\n",
    "            output, hidden = rnn(inp.unsqueeze(0), hidden)\n",
    "\n",
    "            # softmax gives the same sampling distribution as exp(scores / temperature)\n",
    "            # but cannot overflow to inf for large scores.\n",
    "            output_dist = torch.softmax(output.view(-1) / temperature, dim=0)\n",
    "            top_i = torch.multinomial(output_dist, 1)[0]\n",
    "\n",
    "            # Add predicted character to string and use its index as the next input\n",
    "            predicted += all_characters[top_i.item()]\n",
    "            inp = top_i\n",
    "\n",
    "    return predicted"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "N0FTD7fGz12_"
   },
   "source": [
    "# Train RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "executionInfo": {
     "elapsed": 217,
     "status": "ok",
     "timestamp": 1606238153997,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "YFfxLtPdz12_"
   },
   "outputs": [],
   "source": [
    "# Seed Python's and PyTorch's RNGs so batch sampling, weight initialisation and\n",
    "# text generation repeat across Restart & Run All (some GPU kernels may still be\n",
    "# nondeterministic).\n",
    "seed = 0\n",
    "random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "batch_size = 100\n",
    "n_epochs = 750\n",
    "hidden_size = 200\n",
    "n_layers = 1\n",
    "learning_rate = 0.01\n",
    "model_type = 'gru'  # alternatives: 'rnn', 'lstm'\n",
    "print_every = 50\n",
    "plot_every = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "executionInfo": {
     "elapsed": 489,
     "status": "ok",
     "timestamp": 1606238156331,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "BwD65P4lz12_"
   },
   "outputs": [],
   "source": [
    "def eval_test(rnn, inp, target):\n",
    "    \"\"\"Return the average per-character loss of `rnn` on one batch, without gradients.\n",
    "\n",
    "    Inputs:\n",
    "    - inp, target: LongTensors of shape (batch, seq_len), e.g. from load_random_batch\n",
    "\n",
    "    Uses the notebook-level `criterion` and `device`.\n",
    "    \"\"\"\n",
    "    # Take sizes from the batch itself rather than the global batch_size/chunk_len,\n",
    "    # so batches of any shape are scored correctly.\n",
    "    n_batch, seq_len = inp.size()\n",
    "    if seq_len == 0:\n",
    "        raise ValueError('inp must contain at least one time step')\n",
    "    with torch.no_grad():\n",
    "        hidden = rnn.init_hidden(n_batch, device=device)\n",
    "        loss = 0\n",
    "        for c in range(seq_len):\n",
    "            output, hidden = rnn(inp[:, c], hidden)\n",
    "            loss += criterion(output.view(n_batch, -1), target[:, c])\n",
    "\n",
    "    return loss.item() / seq_len"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Vie3D97sz12_"
   },
   "source": [
    "### Train function\n",
    "\n",
    "**TODO**: Fill in the train function. You should initialize the hidden state using your RNN's `init_hidden` function, set the model gradients to zero, and loop over each time step (character) in the input tensor. For each time step, compute the output of the RNN and compute the loss between that output and the corresponding ground-truth time step in `target`. The loss should be averaged over all time steps. Lastly, call backward on the averaged loss and take an optimizer step.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "executionInfo": {
     "elapsed": 208,
     "status": "ok",
     "timestamp": 1606238160465,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "3o4Vbaxjz12_"
   },
   "outputs": [],
   "source": [
    "def train(rnn, input, target, optimizer, criterion):\n",
    "    \"\"\"\n",
    "    Run one optimisation step of `rnn` on a batch of character chunks.\n",
    "\n",
    "    Inputs:\n",
    "    - rnn: model\n",
    "    - input: input character data tensor of shape (batch_size, chunk_len)\n",
    "    - target: target character data tensor of shape (batch_size, chunk_len)\n",
    "    - optimizer: rnn model optimizer\n",
    "    - criterion: loss function\n",
    "\n",
    "    Returns:\n",
    "    - loss: computed loss value as python float\n",
    "    \"\"\"\n",
    "    batch_size, chunk_size = input.size()\n",
    "    if chunk_size == 0:\n",
    "        raise ValueError('input must contain at least one time step')\n",
    "\n",
    "    # Start every chunk from a fresh hidden state and clear stale gradients.\n",
    "    hidden = rnn.init_hidden(batch_size, device=device)\n",
    "    rnn.zero_grad()\n",
    "\n",
    "    # Feed the chunk one time step at a time, accumulating the loss between the\n",
    "    # scores for step t and the ground-truth character at step t.\n",
    "    loss = 0\n",
    "    for t in range(chunk_size):\n",
    "        output, hidden = rnn(input[:, t], hidden)\n",
    "        # Flatten to (batch_size, n_scores) exactly as eval_test does.\n",
    "        loss += criterion(output.view(batch_size, -1), target[:, t])\n",
    "\n",
    "    # Average over time steps, backpropagate, and update the parameters.\n",
    "    loss = loss / chunk_size\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Return a plain float (as documented) so callers that accumulate it do not\n",
    "    # chain autograd graphs or keep device tensors alive.\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 619001,
     "status": "ok",
     "timestamp": 1606201661640,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "qgX0yOW5z12_",
    "outputId": "658157e1-7054-4211-91a2-74b36c9bce3c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training for 750 epochs...\n",
      "[0m 41s (50 6%) train loss: 1.9273, test_loss: 1.9712]\n",
      "Whe the wecth you beath the died it the and serted great hiss\n",
      "and do say irs it so he me coblow.\n",
      "\n",
      "SAFD \n",
      "\n",
      "[1m 21s (100 13%) train loss: 1.7240, test_loss: 1.7752]\n",
      "What him; who the king\n",
      "the but the take my lords,\n",
      "Thuch me contaito,\n",
      "I light you rence your to as is c \n",
      "\n",
      "[2m 0s (150 20%) train loss: 1.6449, test_loss: 1.7337]\n",
      "What, you must thou be not the lord:\n",
      "And do not the poor younger,\n",
      "And the wosed grean of his his loody \n",
      "\n",
      "[2m 41s (200 26%) train loss: 1.6108, test_loss: 1.6566]\n",
      "What where, I am say with my some vroiser well.\n",
      "\n",
      "MACHIO:\n",
      "Nay, if I have not now that course: I stake t \n",
      "\n",
      "[3m 22s (250 33%) train loss: 1.5616, test_loss: 1.6215]\n",
      "Which than say thee other a hearts,\n",
      "And come in a prayers and pice.\n",
      "\n",
      "HERMIO:\n",
      "Go thirm. There the cause \n",
      "\n",
      "[4m 3s (300 40%) train loss: 1.5902, test_loss: 1.6298]\n",
      "Where on your hade, sischarty;\n",
      "For thou hasses of my word, if What and fone.\n",
      "\n",
      "BOTIPHOLUS OF EPHESUS:\n",
      "W \n",
      "\n",
      "[4m 43s (350 46%) train loss: 1.5108, test_loss: 1.6155]\n",
      "Whorminess your received so, good grace\n",
      "A trose me report your in past done crudy,\n",
      "Now all the fire th \n",
      "\n",
      "[5m 24s (400 53%) train loss: 1.5358, test_loss: 1.5887]\n",
      "Whe a great for him to keep thee.\n",
      "\n",
      "ORTING:\n",
      "Whiles, and none else the barge and out the laws,\n",
      "These tro \n",
      "\n",
      "[6m 4s (450 60%) train loss: 1.5214, test_loss: 1.5724]\n",
      "Where I say, the pardon.\n",
      "\n",
      "LONGAVILLE:\n",
      "O Duke the blows? an not be of my most king,\n",
      "Thou should should  \n",
      "\n",
      "[6m 44s (500 66%) train loss: 1.5071, test_loss: 1.5824]\n",
      "Where! my lardon semons, feath\n",
      "That so rother, I'll resemn me.\n",
      "\n",
      "KING HENRY V:\n",
      "The son I am a speak to  \n",
      "\n",
      "[7m 25s (550 73%) train loss: 1.4938, test_loss: 1.5677]\n",
      "What stay her, by my freast:\n",
      "The bearest as man say the fine ever ships,\n",
      "And her words my body, and mu \n",
      "\n",
      "[8m 5s (600 80%) train loss: 1.4951, test_loss: 1.5960]\n",
      "Where is the great two tempt\n",
      "Or good distress' all thou earth to hords.\n",
      "\n",
      "LEONTES:\n",
      "I will were you thou \n",
      "\n",
      "[8m 45s (650 86%) train loss: 1.4402, test_loss: 1.5782]\n",
      "Where then, so measure.\n",
      "\n",
      "IAGO:\n",
      "I'll for your dweelight of this great one fear\n",
      "To will not be majesty,  \n",
      "\n",
      "[9m 26s (700 93%) train loss: 1.4637, test_loss: 1.5957]\n",
      "Whe even! Have stands, how they won and the\n",
      "peads yourself, here give me and this hungro's peace:\n",
      "That \n",
      "\n",
      "[10m 8s (750 100%) train loss: 1.4795, test_loss: 1.5409]\n",
      "Whee you him.\n",
      "\n",
      "JENTIO:\n",
      "You have find: to this my time we three bear\n",
      "That renderfeation to thy restore  \n",
      "\n"
     ]
    }
   ],
   "source": [
    "rnn = RNN(n_characters, hidden_size, n_characters, model_type=model_type, n_layers=n_layers).to(device)\n",
    "rnn_optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "start = time.time()\n",
    "all_losses = []\n",
    "test_losses = []\n",
    "loss_avg = 0\n",
    "test_loss_avg = 0\n",
    "\n",
    "print(\"Training for %d epochs...\" % n_epochs)\n",
    "for epoch in range(1, n_epochs + 1):\n",
    "    # float() keeps the running sums plain numbers even if train() returns a tensor,\n",
    "    # so they neither chain autograd graphs across epochs nor hold GPU memory,\n",
    "    # and all_losses stays directly plottable.\n",
    "    loss = float(train(rnn, *load_random_batch(train_text, chunk_len, batch_size), rnn_optimizer, criterion))\n",
    "    loss_avg += loss\n",
    "\n",
    "    test_loss = eval_test(rnn, *load_random_batch(test_text, chunk_len, batch_size))\n",
    "    test_loss_avg += test_loss\n",
    "\n",
    "    if epoch % print_every == 0:\n",
    "        print('[%s (%d %d%%) train loss: %.4f, test_loss: %.4f]' % (time_since(start), epoch, epoch / n_epochs * 100, loss, test_loss))\n",
    "        print(generate(rnn, 'Wh', 100, device=device), '\\n')\n",
    "\n",
    "    if epoch % plot_every == 0:\n",
    "        all_losses.append(loss_avg / plot_every)\n",
    "        test_losses.append(test_loss_avg / plot_every)\n",
    "        loss_avg = 0\n",
    "        test_loss_avg = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "executionInfo": {
     "elapsed": 234,
     "status": "ok",
     "timestamp": 1606201719786,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "UsbBglR8z12_"
   },
   "outputs": [],
   "source": [
    "# Save the trained network and its optimizer state.\n",
    "torch.save(rnn.state_dict(), './rnn_generator.pth')\n",
    "torch.save(rnn_optimizer.state_dict(), \"./rnn_generator_optimizer.pth\")\n",
    "\n",
    "# To restore later (map_location lets a GPU-saved checkpoint load on the current device):\n",
    "# rnn.load_state_dict(torch.load(\"./rnn_generator.pth\", map_location=device))\n",
    "# rnn_optimizer.load_state_dict(torch.load(\"./rnn_generator_optimizer.pth\", map_location=device))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0Q7jFQzSz12_"
   },
   "source": [
    "# Plot the Training and Test Losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 284
    },
    "executionInfo": {
     "elapsed": 421,
     "status": "ok",
     "timestamp": 1606201727491,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "kTP_C5eoz12_",
    "outputId": "552293de-b028-4c9f-8539-3da6f6d35394"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7fd39023aba8>]"
      ]
     },
     "execution_count": 21,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXxc5X3v8c9vtO+LJS+SZcsbNrbxggXGImyBYiCASRfICiVJ3dzQNrmhKU3ScpumNzc3SYG0KaVkaZoEEgjLLUsSSAjgODaLbLzJxjveZEuydln7zHP/OCNLtrXZHuloZr7v1+u85mya87Nf0nfOPOc5zzHnHCIiEv0CfhcgIiKRoUAXEYkRCnQRkRihQBcRiREKdBGRGJHo14ELCgpcaWmpX4cXEYlKGzZsOO6cKxxom2+BXlpaSkVFhV+HFxGJSmZ2YLBtanIREYkRCnQRkRihQBcRiREKdBGRGKFAFxGJEQp0EZEYoUAXEYkRURfoO4+18H9+sYMTnT1+lyIiMq5EXaAfqm/jP9bsY8fRZr9LEREZV6Iu0BcW5wBQWaVAFxHpL+oCfVJ2ChMyktl2pMnvUkRExpWoC3QzY35Rts7QRUROE3WBDl6zy67qFjp7gn6XIiIybkRloC8oyqYn5Nhd3ep3KSIi40ZUBvrCot4Lo2pHFxHpFZWBPi0/ncyURLYdUTu6iEivqAz0QMCYPyVbZ+giIv1EX6C/+CLMnMml6V3sONpCMOT8rkhEZFyIvkDPz4f9+ymv2UV7d5D9x3VhVEQEojHQL74YkpOZu28roDtGRUR6RV+gp6TAsmXkbdpAcmJAd4yKiIRFX6ADlJcT2FDBRRNSdIYuIhIWtYFOVxfXdVSx7UgTzunCqIhIdAb6ihUAXHpsJ80dPRxuaPe5IBER/0VnoE+ZAjNmMGv3FkAXRkVEIFoDHaC8nJx33iLBNASAiAhEeaDbsWNcntiqM3QREaI80AGub9ynrosiIkRzoC9cCJmZLD2yg5qWTmpbOv2uSETEV9Eb6ImJcOmllO7aBKgdXUQkegMdoLyc9B2VpHe1qx1dROLesIFuZiVm9qqZbTezSjP77AD7fNTMtpjZVjNbZ2aLR6fc05SXY6EQ1584qDN0EYl7IzlD7wHudc7NBy4D7jGz+aftsx+4yjl3EfBV4NHIljmIyy4D4P0Ne/SwCxGJe8MGunPuqHNuY3i+BdgBFJ+2zzrnXEN48Q1gaqQLHVBeHsyfz6KD2zlY30ZzR/eYHFZEZDw6qzZ0MysFlgJvDrHbJ4FfDvLzq82swswqamtrz+bQgysvp/jdzZgLsV3t6CISx0Yc6GaWCTwNfM45N2Bymtk1eIF+30DbnXOPOufKnHNlhYWF51LvmcrLSWpqZGbdEfVHF5G4NqJAN7MkvDB/zDn3zCD7LAK+B6xyztVFrsRhhG8wuqZ+t87QRSSujaSXiwHfB3Y45x4YZJ9pwDPAx51zuyJb4jAuuADy87nq+G62qaeLiMSxxBHscznwcWCrmW0Kr/sSMA3AOfcIcD8wAXjYy396nHNlkS93AGawYgXzt1Syt/YEHd1BUpMSxuTQIiLjybCB7pxbC9gw+3wK+FSkijpr5eVMePFFMk808+6xFpaU5PpWioiIX6L7TtFe4Xb0pVXv6sKoiMSt2Aj0Sy7BJSRQXr1LQwCISNyKjUDPyMCWLOHyml0aAkBE4lZsBDpAeTlz3tvO7qpGuoMhv6sRERlzMRXoyZ3tzKzay97aVr+rEREZczEV6ADLjuzQQF0iEpdiJ9BLSnBFRVx6dKfa0UUkLsVOoJth5eVccmynerqISFyKnUAHKC9nUt1RanfuJxRyflcjIjKmYi7QAebu38bB+jafixERGVuxFehLlxJKSfEujKodXUTiTGwFenIylJVRVvWu2tFFJO7EVqADgcsvZ+GxPew6EKEnIomIRImYC3TKy0kK
9uDe3oBzujAqIvEj9gJ9xQoAZu/dQnVzp8/FiIiMndgL9IkT6Zg+g4s1lK6IxJnYC3Qg4fJylh3ZQaUCXUTiSEwGetIV76PwRCPVm7f7XYqIyJiJyUDvvcEo5e03fS5ERGTsxGagL1hAV3oGM3dvoeFEl9/ViIiMidgM9IQETiy9xGtH1w1GIhInYjPQgdQr38fc2gPs2n3Y71JERMZEzAZ62tVXkOBCtK19w+9SRETGRMwGOsuXEzIjfeNbflciIjImYjfQc3Konz6Hmbu3cKKzx+9qRERGXewGOtBxyXIuPvIuO440+l2KiMioi+lAz7zmSrI7T3Bk3Ua/SxERGXUxHeg5114JQPfatT5XIiIy+mI60G3OHJozc8l5p8LvUkRERl1MBzpmVC+8mFl7ttDZE/S7GhGRURXbgQ70LL+MmfVH2LfjgN+liIiMqpgP9NxrrwKg9tev+1yJiMjoivlAn/T+99EdSIB16/wuRURkVMV8oAcy0jlQcgH5W3RhVERiW8wHOsDxRcuYtX8HwU4NpSsisSsuAp3yctJ6Oql6fb3flYiIjJq4CPSC668GoPE3ujAqIrErLgJ9+uK5VGUXkvCmhtIVkdg1bKCbWYmZvWpm282s0sw+O8A+Zmb/YmZ7zGyLmV08OuWem6SEAHtmX8TErRrTRURi10jO0HuAe51z84HLgHvMbP5p+9wIzAlPq4F/j2iVEdC05BIKGqpxhw75XYqIyKgYNtCdc0edcxvD8y3ADqD4tN1WAT9ynjeAXDObEvFqz0PC+y4HoE7t6CISo86qDd3MSoGlwJunbSoG+p/6HubM0MfMVptZhZlV1NbWnl2l56noqstoT0yh9bdrxvS4IiJjZcSBbmaZwNPA55xzzedyMOfco865MudcWWFh4bm8xTmbV5LPlilzSHn79M8iEZHYMKJAN7MkvDB/zDn3zAC7HAFK+i1PDa8bN1KTEtg/ZxGFeyqhvd3vckREIm4kvVwM+D6wwzn3wCC7PQfcGe7tchnQ5Jw7GsE6I6J12aUkBoNQoWEARCT2jOQM/XLg48D7zWxTeLrJzD5tZp8O7/MLYB+wB/gu8JnRKff8pFzhXRhtfVXt6CISexKH28E5txawYfZxwD2RKmq0zFkwg735xeS8/jsy/S5GRCTC4uJO0V7zi7LZWHQhGRveAuf8LkdEJKLiKtCzU5PYf8Ei0poaYO9ev8sREYmouAp0gM5LLvVm9MALEYkxcRfo+ZcspTk5na7frfW7FBGRiIq7QF8wNZd3iufR/bvf+12KiEhExV+gF+WwofhC0nftgKYmv8sREYmYuAv0wqwU9s1ZhDkHb2oYABGJHXEX6AChSy4haAFdGBWRmBKXgT5rVhE7C6cT/L3a0UUkdsRloPe2o/PmmxAM+l2OiEhExGmgZ7Oh+EISWlpg+3a/yxERiYi4DPSpeWnsnrnQW1A7uojEiLgMdDMjZ8FcGrPyFOgiEjPiMtABFhTn8PaUeTgFuojEiLgN9IXFObxdNA/bswdqavwuR0TkvMVtoPdeGAVg/Xp/ixERiYC4DfQZBZnsKZlLMDFR7egiEhPiNtATAsaskgnsLZmrM3QRiQlxG+jg3WC0btJc3NtvQ1eX3+WIiJyXuA70hcXZvDF5LtbRAZs2+V2OiMh5ietAX1CUw8aied6C2tFFJMrFdaDPmZRJQ24BTROLFOgiEvXiOtBTEhOYMzGL7TMWwu9/D875XZKIyDmL60AHrz/6moI5UFUFhw75XY6IyDmL+0BfWJzjBTqo2UVEolrcB/qComzenTiDYFqa+qOLSFSL+0C/cEo2oYQEquYuhtdfVzu6iEStuA/0jJREZhRksGbRlbB5M3z7236XJCJyTuI+0MHrj/7wvOvhttvgC1+AtWv9LklE5Kwp0IGFRdkcaeqg4d8ehRkz4Pbb4dgxv8sSETkrCnS8M3SA7W0GTz8NjY1wxx3Q0+NzZSIiI6dAx+vpArDtSBNcdBE8+iis
WQNf/KLPlYmIjJwCHcjLSKY4N43KqmZvxcc+Bp/5DHzrW94Zu4hIFFCgh80vymbjwQZ6giFvxQMPwPLlcPfdsHOnv8WJiIyAAj3sg0uLOdzQziOv7/VWpKTAz3/uvf7hH0Jrq78FiogMQ4EedtNFU7hlcREP/Wa315YOUFICP/sZvPsu/Nmf6aYjERnXFOj9fHXVAgoyU/ifT2yiozvorbz2Wvinf/KC/V//1d8CRUSGoEDvJzc9mW/+ySJ217TyjV/1aze/7z645Ra4915vmF0RkXFo2EA3sx+YWY2ZbRtke46ZPW9mm82s0szujnyZY+eKOYXctWI6P/j9ftbtOe6tDATgRz+C6dO9m46qq/0tUkRkACM5Q/8hcMMQ2+8BtjvnFgNXA/9sZsnnX5p//vbGC5lZmMFf/3wzTe3d3srcXK8LY0MDfOhDuulIRMadYQPdObcGqB9qFyDLzAzIDO8b1WmXlpzAg7cvobqlk688V9m3YfFieOQReO01+PKXfatPRGQgkWhD/w5wIVAFbAU+65wLDbSjma02swozq6itrY3AoUfP4pJc/uKa2TzzzhF+ufVo34Y774RPfxq+8Q149ln/ChQROU0kAn0lsAkoApYA3zGz7IF2dM496pwrc86VFRYWRuDQo+sv3j+bRVNz+NKzW6lp7ujb8NBDcMklcNddsGuXfwWKiPQTiUC/G3jGefYA+4F5EXhf3yUlBHjg9iW0dQW57+ktuN5+6Ckp8NRTkJwMf/RHcOKEv4WKiBCZQD8IXAtgZpOAucC+CLzvuDB7YiZfvHEer+6s5adv9XuI9LRp8PjjUFkJq1frpiMR8d1Iui3+FFgPzDWzw2b2STP7tJl9OrzLV4FyM9sKvALc55w7Pnolj707V5TyvtkF/NOL2zlQ1+9s/Prr4R//0Qv2hx/2r0AREcCcT2eWZWVlrqKiwpdjn4ujTe2sfHANcyZl8eSfryAhYN6GUAhWrYKXXvKG3L3sMn8LFZGYZmYbnHNlA23TnaIjNCUnja/etpANBxr6BvCCvpuOSkrgj/8Yamr8K1JE4poC/SzcuriIDyyawkO/2UVlVVPfhrw876ajujr48Id105GI+EKBfhbMjP9920Ly0pNPHcALYMkSrx39t7+Fv/97/4oUkbilQD9LuenJfOOPF7GrupV/fvm0B1/cfbc3zO7Xvw7//d/+FCgicUuBfg6unjuRj102je+t3c8b++pO3fgv/wLLlnl3lO7e7U+BIhKXFOjn6Es3XUjphAzufXIzzR3dfRtSU7329MRE3XQkImNKgX6O0pMT+efbF3O0qZ2vPLf91I3Tp3t907dt88Z90U1HIjIGFOjn4eJpedxzzWye3niYX207durGlSvhH/4BfvITuP9+6OrypUYRiR8K9PP0V9fOYWFxtjeAV0vHqRv/7u/gIx/xHmF30UXwi1/4U6SIxAUF+nlKSgjw4O1LaO3s4YtPb+WUO28DAXjssb4g/8AH4OabdbFUREaFAj0C5kzK4r4b5vHKuzU88fahM3e48UbYuhW++U1veIAFC7znlLa0jH2xIhKzFOgRcnd5KeWzJvDVF7ZzsK7tzB2Sk+Gv/9obP/2jH/UekHHBBd6wAaEBnwciInJWFOgREggY3/yTxQTM+PyTmwiGBunZMnky/Od/whtveEPw3nUXXH45vP322BYsIjFHgR5BxblpfGXVAioONPDommGGhF++HNav98J9/35v+ZOfhOrqsSlWRGKOAj3CPri0mJsumswDv97J9qrmoXcOBOBP/9Rrhrn3Xvjxj71mmAcfhO7uoX9WROQ0CvQI8wbwuojc9GQ+/+QmOnuCw/9QdrZ3wXTrVigvh89/HhYtgpdfHv2CRSRmKNBHQV6GN4DXu8daeODls3iI9Ny5XhfH55/3huBdudJ7eMbevcP/rIjEPQX6KLlm7kQ+snwaj/5uH+v31g3/A73MvL7q27Z5ozb+9rcwfz58+cvQ2jp6BYtI1FOg
j6Ivhwfw+sQP3+bFLUfP7odTUry+6jt3wh13wNe+5p3BP/64xoYRkQEp0EdRRkoiT/75ChYUZXPP4xv51ks7CQ3WnXEwRUVeX/V162DKFK8P+xVXwMaNo1O0iEQtBfooK8xK4fE/u4wPX1rCd17dw+ofV9DScQ49WFasgLfegu9/3+sVU1YG110H3/0uHD8e+cJFJOoo0MdAcmKAr33wIr66agGv7azlgw+vY1/tObSHBwLwiU94Y8Hcfz8cPAirV3s3K91wg9envbEx8v8AEYkK5nxqjy0rK3MVFRW+HNtP6/fWcc/jG+kOhvjXDy/l6rkTz/3NnINNm+CJJ7zpvfcgKcnrHXPHHXDrrV6XSBGJGWa2wTlXNuA2BfrYO1Tfxuofb2DnsWbuu2Eeq6+ciZmd35s65w0f8MQT8OSTcPiwd2H1ppu8cL/5ZsjIiMw/QER8o0Afh9q6evjCz7fw4tajrFpSxP/9o0WkJiVE5s1DIW+smCeegJ//HI4ehbQ0L9TvuMML+bS0yBxLRMaUAn2ccs7x8Gt7+dbLO1lYlMN/fHwZRbkRDtpgENau9cL9qaegthYyM73mmDvu8JpnUlIie0wRGTUK9HHuN9ur+dwTm0hNCvDIx5ZRVpo/Ogfq6YHXXvPC/ZlnoL4ecnLgttu8cL/uOq8NXkTGLQV6FNhT08Kn/quCI43t/OOqhXz40mmje8DubnjlFS/cn30WmpogP98bauC66+Cqq6C4eHRrEJGzpkCPEk1t3fzlz95hza5a7lwxnb+/eT5JCWPQs7Sz0xsI7Ikn4IUXvHAHmD3bC/arr/ZeS0pGvxYRGZICPYoEQ45v/Opd/mPNPpbPyOfhj17MhMwxbOMOBmHzZq9p5vXXvUfm9fZtnzGjL9yvugpKS8euLhEBFOhR6dl3DnPf01spzEzh0TuXsaAox59CgkFvWN/XX++b6uu9bdOnn3oGP2OGN7iYiIwaBXqU2nK4kdU/2kBTezff/JNF3LyoyO+SvC6RlZVesPeexfcOPTB1al+4X301zJqlgBeJMAV6FKtp6eB//GQjGw40cM81s7j3D+YSCIyjkHQOduzoC/fXXoOaGm9bUVFf88wVV3ijRSZEqK+9SJxSoEe5zp4g9/+/Sp6oOMS18yby4IeWkJ06TrsXOucN+dv/DP5oeOjgzExYutQbWKx3mj3bG6NGREZEgR4DnHP8+I0DfOX57ZROSOe7d5YxszDT77KG5xzs2eM9ELuiwpveeQc6Orzt2dmwbNmpIa+2eJFBKdBjyPq9dXzmsQ30hByfuHwGty4pYlY0BHt/PT2wfXtfwFdUeD1rurq87Xl5XrD3D/pp0xTyIijQY86h+ja+9OxW1u45jnMwf0o2ty4p4pbFRRRHeuiAsdLV5V1s7R/yW7Z44Q9QUHDqWXxZmddGr5CXOKNAj1HVzR28sOUoz22uYvMhr6942fQ8bl1SxI0Lp1CYFeVjtHR0eF0m+4d8ZaXXlRK8ceBLS70z+t4pN3fo5awsfQhIVDuvQDezHwA3AzXOuYWD7HM18BCQBBx3zl01XFEK9Mg6WNfG81uqeG5TFTurWwgYXD67gFsWF7FywWRy0sbpRdSz1dbmNc/0BnxVFTQ09E2NjUM/czUQGD708/K8YRAKCmDChL5Jg5jJOHC+gX4l0Ar8aKBAN7NcYB1wg3PuoJlNdM7VDFeUAn307DzWwvObq3hucxUH69tITghw1dxCbl1cxHUXTiItOYa7DoZC0NJyasCfHvhDLXcP8XjAjAwv2E8P+qHWZWbqG4FE1Hk3uZhZKfDCIIH+GaDIOfd3Z1OUAn30OefYfLiJ5zdX8cKWKqqbO0lPTuAP5k/ilkVFXHlBIcmJ6jJ4knPQ3u7dCVtfD3V13k1TdXVnTv3XNzQM/p5JSaeGfG6uN+XkDDzffzknR6NfyhlGO9B7m1oWAFnAt51zPxrkfVYDqwGmTZu27MCBAyP8J8j5CoYcb+2v57nNVfxy21Ea
27rJSUvixoWTuXVxEctnTiBhPN2wFE2CQS/URxL+TU3e1NjovQ7395eePnDY95+fONEbhqG01BtATR8CMW20A/07QBlwLZAGrAc+4JzbNdR76gzdP93BEGt3H+e5zVW8XHmME11BCrNS+MBFU7hlcRFLS3LH192osSoUgtZWL9x7p96wH2q+//LpTUSBgDfscWlpX8j3n0pKIDl5zP+pEjlDBXpiBN7/MFDnnDsBnDCzNcBiYMhAF/8kJQS4Zt5Erpk3kfauIK/urOG5TVU8/tZBfrjuPSZlp3D9/MmsXDCZ5TPzx2YI33gUCHg3VmVne/3sz1ZvE9GxY3DggPeQ8Pfe65v/3e/g8ce9D45eZl53z/4h3z/4p03Txd8oFokz9AuB7wArgWTgLeBDzrltQ72nztDHn+aObl7ZUc1L26p5bVcNHd0hctKSuPbCiaxcMJkr5xTG9gXVWNTdDUeOnBr0/adDh/q6gfYqKvJCPi/Pe/Zsaur5Tae/R3q6LhSfh/Pt5fJT4GqgAKgG/hdemznOuUfC+3wBuBsIAd9zzj00XFEK9PGtvSvImt21vFR5jN9sr6a5o4e0pASuuqCQlQsn8f55k2KnK2Q86+nxun6eHvQHDkBzs3cvQEeH902g/3z/s/6z1fvNJCen7/X0+aG29c4nRqKBIfroxiI5L93BEG/uq+elymO8vP0Y1c2dJAaMFbMmsHLBZK6fP4mJ2al+lyljqaenL+D7T/2Df6Cprc3rVtrU5H1g9F4kPn15qO6jvdLTT/1gyMrqmzIzh14+fV1aWtR8a1CgS8SEQo5Nhxt5qfIYL207xnt1bZjB0pJcbljotbtPn5Dhd5kSzZzzHos4UNAPNd/S4k2trX3znZ0jO2YgMLLgH+qDov/8KDYrKdBlVDjn2FXd6oV75TEqq5oBmDc5i+sXTGblgknMn5KNRcmZj8Sg7u5TA36g0B9s3UDLvaOEDsds6A+CVavg9tvP6Z+kQJcxcai+zWuWqazm7QP1OAcl+WmsnD+Z6+ZPYum0XFISdVFVolj/D4j+gT/Uh8FA21avhr/5m3MqQYEuY662pZPf7Kjmpcpj/H7PcbqDjtSkAGXT81kxawLlsyZwUXEOieoSKXJWFOjiq+aObt7YW8e6vXWs31vHzuoWADJTElk+wwv4FbMmcOHkbN3QJDKM0b6xSGRI2alJXL9gMtcvmAzA8dZO3tjXF/CvvOuN5ZabnsSKmd7Z+4pZE5hVmKn2d5GzoDN08d3RpnbWh8/g1+05TlWTd+GpMCuF8nDzTPmsAkry032uVMR/anKRqOGc42B928mz93V76zje6nU9K85N88J99gRWzCxgco76vkv8UaBL1HLOsaem9WTAr99XR1O7d9PJzIIMFhTnMC0/jZK8dKblp1OSn86UnFRdbJWYpUCXmBEKObYfbWb93jre2FfH7ppWqhrb6Qn1/R4nBoyi3LRwwKdRku+F/bT8dEry0slNT1LbvEQtBbrEtJ5giKNNHRyqb+NQQxsH69s4WN/uLde3UXei65T9s1ISKQmHfW/QTw2/FuemkZqkvvIyfqmXi8S0xIRAOKAHvmja2tnD4YY2DtZ5Ye8Ffzt7a0/w2s5aOnv6Bpoyg8nZqcwqzGTOpEwumJTFBZMymTMpi+xUDUYm45sCXWJeZkoi8yZnM29y9hnbQiHH8dbO8Fl9G4fq2zlQd4LdNa387K1DtHf3DS07OTv1jJCfMzGTLAW9jBMKdIlrgYAxMTuVidmplJXmn7ItFHIcbmhnV3ULu2pa2F3dyq7qFn7yxoFTzuqLclKZ0y/kLwgHfUaK/rxkbOk3TmQQgYAxbUI60yakc938SSfXB0OOQ/Vt7KpuYXeNF/K7qltZv6+Orn5BX5ybxgXhM/o5k7KYPTGTqXlpTMhI1kVZGRUKdJGzlBAwSgsyKC3I4PoFfet7giEO1rexq7qV3dUt7KrxXteGx7LplZIYoDg3jaLcNIpyU8OvaUwNv07O
SdWFWTknCnSRCElMCDCzMJOZhZncsHDyyfU9wRDv1bWxt9brYulNHRxpbOe1nbXUtJw5ZndBZgrF4bDvC//e+VTydZYvA1Cgi4yyxIQAsydmMnti5oDbO3uCHGvyAr6qsYOqxnaONLRT1dTOzuoWXt3pPd+1v9SkwMmAn5iVSl56EnkZyeSlJ5OXnkRuejJ5GUnkpSeTm56kYYvjhAJdxGcpiQlMn5Ax6JOenHM0tHVT1djO4Yb2vrP8pnaONHawp+Y4DW1dZ4R+fxnJCaeEfP/gz8/wQv/k+vA+6ckJ+hYQZRToIuOcmZGf4QXvwuKcQffr6A7S0NZFw4luGtu6qG/roqGtm8YT3mtDW1d46uZgfRsNJ7po7ugZ9P1SEgMUZKZQkJVCYWYyBZkpTAi/9k6FWd5yTpruvh0PFOgiMSI1KYEpOWlMyUkb8c/0BEM0tnsfAA1t3dSf6Do5X9fayfHWLo63dnKksYPNh5uoP9FFMHTm3eWJATsj7Auykik87UOgMCuF/PRkjXs/ShToInEsMSFwMmxHIhRyNLR1nQz64/1C/3iLt1x3oovd1S0cb+2iK3hmM1ByQoApuakU5fRe6O3r6dPb8yc9WdF0LvS/JiIjFggYEzJTmJCZwlyyhtzXOUdzR8/JsK870UVtSydVTX0Xf9ftPU51cwenn/TnpicNGvjFuWkUZqWQoLP8MyjQRWRUmBk5aUnkpCUxq3DgHj4A3cEQ1c0dHG0K9/Dp17XzcEMbb+6vo+W0tv7EgDEpO/VkN86CzBTMwDlw9L46escedM6dsb53md7lAbYlBCAtKYHUpARSkhLC8wFS+833rfeWT84nJpCaHCA5ITBm1xcU6CLiq6SEAFPz0pmaN/gTqZo7ujna2Bf4R5v6+vJXHGigrtUbUdMMDO/DxAD6L5+2zcI79K0H49T9ekIhOrpDdHQHTxnu4WyYQWpiAmnJCaQmeh8GH1k+jU9dMfOc3m8oCnQRGfeyU5PInpzE3MlDN/OMplDI0RUM0d4VpKMnSEd3//neKXTytT28rrM7GJ73trV3B0d8zeJsKdBFREYgEDBSAwnjelgGPadLRCRGKNBFRGKEAl1EJEYo0EVEYoQCXUQkRijQRURihAJdRCRGKJ/ik3QAAARmSURBVNBFRGKEOXfmUJhjcmCzWuDAOf54AXA8guWMtmiqN5pqheiqN5pqheiqN5pqhfOrd7pzrnCgDb4F+vkwswrnXJnfdYxUNNUbTbVCdNUbTbVCdNUbTbXC6NWrJhcRkRihQBcRiRHRGuiP+l3AWYqmeqOpVoiueqOpVoiueqOpVhileqOyDV1ERM4UrWfoIiJyGgW6iEiMiLpAN7MbzGynme0xs7/1u57BmFmJmb1qZtvNrNLMPut3TSNhZglm9o6ZveB3LUMxs1wze8rM3jWzHWa2wu+ahmJm/zP8e7DNzH5qZql+19Sfmf3AzGrMbFu/dflm9msz2x1+zfOzxl6D1PrN8O/CFjN71sxy/ayxv4Hq7bftXjNzZlYQiWNFVaCbWQLwb8CNwHzgw2Y239+qBtUD3Oucmw9cBtwzjmvt77PADr+LGIFvA79yzs0DFjOOazazYuCvgDLn3EIgAfiQv1Wd4YfADaet+1vgFefcHOCV8PJ48EPOrPXXwELn3CJgF/DFsS5qCD/kzHoxsxLgeuBgpA4UVYEOXArscc7tc851AT8DVvlc04Ccc0edcxvD8y14gVPsb1VDM7OpwAeA7/ldy1DMLAe4Evg+gHOuyznX6G9Vw0oE0swsEUgHqnyu5xTOuTVA/WmrVwH/FZ7/L+C2MS1qEAPV6px72TnXE158A5g65oUNYpD/W4AHgb8BItYzJdoCvRg41G/5MOM8JAHMrBRYCrzpbyXDegjvF+zcHm8+dmYAtcB/hpuHvmdmGX4XNRjn3BHgW3hnYkeBJufcy/5WNSKTnHNHw/PHgEl+FnMWPgH80u8ihmJmq4AjzrnNkXzfaAv0qGNm
mcDTwOecc81+1zMYM7sZqHHObfC7lhFIBC4G/t05txQ4wfhpDjhDuO15Fd4HURGQYWYf87eqs+O8/s3jvo+zmX0Zr7nzMb9rGYyZpQNfAu6P9HtHW6AfAUr6LU8NrxuXzCwJL8wfc84943c9w7gcuNXM3sNrynq/mf3E35IGdRg47Jzr/cbzFF7Aj1fXAfudc7XOuW7gGaDc55pGotrMpgCEX2t8rmdIZvanwM3AR934vsFmFt6H++bw39tUYKOZTT7fN462QH8bmGNmM8wsGe/C0nM+1zQgMzO8Nt4dzrkH/K5nOM65LzrnpjrnSvH+X3/rnBuXZ5HOuWPAITObG151LbDdx5KGcxC4zMzSw78X1zKOL+L28xxwV3j+LuC/faxlSGZ2A15z4a3OuTa/6xmKc26rc26ic640/Pd2GLg4/Ht9XqIq0MMXPf4CeAnvD+JJ51ylv1UN6nLg43hnupvC001+FxVD/hJ4zMy2AEuAr/lcz6DC3ySeAjYCW/H+7sbVrepm9lNgPTDXzA6b2SeBrwN/YGa78b5lfN3PGnsNUut3gCzg1+G/tUd8LbKfQeodnWON728mIiIyUlF1hi4iIoNToIuIxAgFuohIjFCgi4jECAW6iEiMUKCLiMQIBbqISIz4/2Tvb569jPFKAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "\n",
    "# One point per plotting interval; float() accepts both Python floats and 0-d tensors.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot([float(v) for v in all_losses], label='train')\n",
    "ax.plot([float(v) for v in test_losses], color='r', label='test')\n",
    "ax.set(title='Average loss per plotting interval', xlabel='plotting interval', ylabel='loss')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v_aPt1LMz12_"
   },
   "source": [
    "# Evaluate text generation\n",
    "\n",
    "Check what the generated text looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1279,
     "status": "ok",
     "timestamp": 1606201746205,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "-gjbCojvz12_",
    "outputId": "03f5f558-ebea-4cf9-9613-4f15552bb322"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The will benest his mans me,\n",
      "Hame that will not be a good the provity.\n",
      "\n",
      "TRIA:\n",
      "He dew that with merry a man for the strange.\n",
      "I then to the rash, so must came of the chamuness, and that I'll treason dost\n",
      "the heaven! how there. The run of these thou instress\n",
      "Which wast true come come on my tongue.\n",
      "\n",
      "KATHARINE:\n",
      "My lord, the crown English am a thanks, and I\n",
      "have you weep you galls. O, I wast thy change;\n",
      "And go turn of my love to the master.'\n",
      "\n",
      "ARCHBISHOP OF YORK:\n",
      "I'll find by dogs, noble.\n",
      "\n",
      "SAMLET:\n",
      "The matter were be true and treason\n",
      "Free supples'd best the soldiered.\n",
      "\n",
      "TITUS ANDRONICUS:\n",
      "I ever a bood;\n",
      "But one a stand have a court in thee: which man as thy break on my bed\n",
      "'As oath a women; there and shake me; and whencul, comes the house\n",
      "For them he wall; and no live away. Fies, sir.\n",
      "\n",
      "TRANIO:\n",
      "The congrain upon a dream, for he did pluck'd\n",
      "That they should cheen gate you gate forth in dicess the\n",
      "proclay read sative the more and like me weld my life.\n",
      "\n",
      "PLETEY:\n",
      "There that I never my reme, bay hath be \n"
     ]
    }
   ],
   "source": [
    "generated_text = evaluate(rnn, prime_str='Th', predict_len=1000)\n",
    "print(generated_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6mwFrJahz12_"
   },
   "source": [
    "# Hyperparameter Tuning\n",
    "\n",
    "Some things you should try in order to improve your network's performance:\n",
    "- Using different RNN types: switch the basic RNN in your model to a GRU and then to an LSTM, and compare all three.\n",
    "- Adding one or two more layers\n",
    "- Increasing the hidden layer size\n",
    "- Changing the learning rate\n",
    "\n",
    "**TODO:** Try changing the RNN type and hyperparameters. Record your results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "executionInfo": {
     "elapsed": 240,
     "status": "ok",
     "timestamp": 1606202277796,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "iWWxKfmez12_"
   },
   "outputs": [],
   "source": [
    "# Optimization schedule\n",
    "n_epochs = 2000\n",
    "batch_size = 100\n",
    "learning_rate = 0.01  # alternatives noted: 0.1, 0.005\n",
    "\n",
    "# Architecture: model_type is one of 'rnn', 'gru', 'lstm'\n",
    "model_type = 'lstm'\n",
    "n_layers = 2\n",
    "hidden_size = 150\n",
    "\n",
    "# Logging cadence (in epochs)\n",
    "print_every = 50\n",
    "plot_every = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "executionInfo": {
     "elapsed": 235,
     "status": "ok",
     "timestamp": 1606202280270,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "Bze1BX_Iz13A"
   },
   "outputs": [],
   "source": [
    "def train(rnn, input, target, optimizer, criterion):\n",
    "    \"\"\"\n",
    "    Run one optimization step of the RNN on a batch of character sequences.\n",
    "\n",
    "    Inputs:\n",
    "    - rnn: model\n",
    "    - input: input character data tensor of shape (batch_size, chunk_len)\n",
    "    - target: target character data tensor of shape (batch_size, chunk_len)\n",
    "    - optimizer: rnn model optimizer\n",
    "    - criterion: loss function\n",
    "\n",
    "    Returns:\n",
    "    - loss: loss averaged over time steps, as a Python float\n",
    "    \"\"\"\n",
    "    batch_size = input.size(0)\n",
    "    n_steps = input.size(1)\n",
    "\n",
    "    # Fresh hidden state for every batch; gradients from the previous step are cleared.\n",
    "    hidden = rnn.init_hidden(batch_size, device=device)\n",
    "    rnn.zero_grad()\n",
    "\n",
    "    # Feed one character column (time step) at a time, carrying the hidden state forward.\n",
    "    loss = 0\n",
    "    for t in range(n_steps):\n",
    "        output, hidden = rnn(input[:, t], hidden)\n",
    "        loss += criterion(output, target[:, t])\n",
    "\n",
    "    # Average over time steps so the loss scale does not depend on sequence length.\n",
    "    loss = loss / n_steps\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Return a detached Python float (as documented): returning the tensor would let\n",
    "    # callers that accumulate it (loss_avg += loss) keep each step's autograd graph alive.\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1725154,
     "status": "ok",
     "timestamp": 1606204008444,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "NddmE2mrz13A",
    "outputId": "af913c78-7ab3-4355-e642-5420179fbf6e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RNN(\n",
      "  (embedding_object): Embedding(100, 150)\n",
      "  (dn_object): LSTM(150, 150, num_layers=2)\n",
      "  (linear): Linear(in_features=150, out_features=100, bias=True)\n",
      ")\n",
      "Training for 2000 epochs...\n",
      "[0m 44s (50 2%) train loss: 2.2568, test_loss: 2.2478]\n",
      "Whar unthe dor\n",
      "mistass the some the! Tarler I moovoby rofod. Whear are citinotarlon me,\n",
      "My tnest hom.\n",
      " \n",
      "\n",
      "[1m 28s (100 5%) train loss: 1.9382, test_loss: 1.9670]\n",
      "Whund to a frovey for ox sock withich and for hife, lown,\n",
      "On Worcham, her word?\n",
      "\n",
      "ROT:\n",
      "Ay of hersince s \n",
      "\n",
      "[2m 11s (150 7%) train loss: 1.7706, test_loss: 1.8336]\n",
      "What herely evere your are to\n",
      "lathe this marring, as will mad,\n",
      "Nor do with from mift\n",
      "strave tale thing \n",
      "\n",
      "[2m 54s (200 10%) train loss: 1.6878, test_loss: 1.7490]\n",
      "What be use as the grace of the base can prince,\n",
      "His that I once\n",
      "grow, do the enemperion, she more to  \n",
      "\n",
      "[3m 37s (250 12%) train loss: 1.6231, test_loss: 1.7052]\n",
      "Where him are you canst\n",
      "And she wront do great yey and eathers\n",
      "Your abdess it thousand out and not the \n",
      "\n",
      "[4m 21s (300 15%) train loss: 1.5808, test_loss: 1.6612]\n",
      "Who to loves, not.\n",
      "\n",
      "TITUS ANDRONICUS:\n",
      "There's it hered to me lovesion now,\n",
      "He well farewell, for so th \n",
      "\n",
      "[5m 4s (350 17%) train loss: 1.5867, test_loss: 1.6273]\n",
      "Wh's guard threate she\n",
      "have ever will be love of the seeedness.\n",
      "Nay, and my send you are treason!\n",
      "What \n",
      "\n",
      "[5m 49s (400 20%) train loss: 1.5436, test_loss: 1.6458]\n",
      "Where is did some atten in it shall too.\n",
      "They find a man now not his failly\n",
      "to be shupan?\n",
      "Or free succ \n",
      "\n",
      "[6m 33s (450 22%) train loss: 1.5653, test_loss: 1.6532]\n",
      "When was this before, and in his both is a; her.\n",
      "\n",
      "WARWICK:\n",
      "That with my heart, Rome and a fure to your \n",
      "\n",
      "[7m 16s (500 25%) train loss: 1.4941, test_loss: 1.5937]\n",
      "What is dead. O virtue to doing\n",
      "I say and it to the saulty,\n",
      "If like allation go to it, no supple\n",
      "As my \n",
      "\n",
      "[8m 1s (550 27%) train loss: 1.4899, test_loss: 1.5902]\n",
      "Where you'ld find so eniin enemy his\n",
      "served than heart of the lady: and make so, thousand thine\n",
      "And tr \n",
      "\n",
      "[8m 43s (600 30%) train loss: 1.4549, test_loss: 1.5888]\n",
      "Who we are it is my good tarried.\n",
      "\n",
      "TRANIO:\n",
      "Where, if she satirely hand that he vining being-drawgatien \n",
      "\n",
      "[9m 25s (650 32%) train loss: 1.4372, test_loss: 1.5399]\n",
      "What, the king too lost,--and this trift you of them:\n",
      "Where she is an honest by thy will will boy and  \n",
      "\n",
      "[10m 8s (700 35%) train loss: 1.4671, test_loss: 1.5415]\n",
      "When I such ere not be feuse on taking knot\n",
      "the tontune\n",
      "The see again, gentle whereoven his set them l \n",
      "\n",
      "[10m 51s (750 37%) train loss: 1.4236, test_loss: 1.5733]\n",
      "Why which is a good thrush our sent your bloody should\n",
      "caby, sir: I am where with all us.\n",
      "\n",
      "LADY MALCUL \n",
      "\n",
      "[11m 32s (800 40%) train loss: 1.4346, test_loss: 1.5570]\n",
      "What will be bring to him.'\n",
      "\n",
      "IAGO:\n",
      "Degan! Master, to beat dishdient\n",
      "for thine take of the tatures.\n",
      "\n",
      "CO \n",
      "\n",
      "[12m 15s (850 42%) train loss: 1.4318, test_loss: 1.5636]\n",
      "What can find it yet come:\n",
      "Then leave the sear is and of law:\n",
      "There as the Forth! I might above and do \n",
      "\n",
      "[12m 59s (900 45%) train loss: 1.4155, test_loss: 1.5372]\n",
      "Why, to eat his upon this honour,\n",
      "They day, not to be--\n",
      "What, for you welcome, I strench but the forme \n",
      "\n",
      "[13m 43s (950 47%) train loss: 1.4170, test_loss: 1.5296]\n",
      "Who stand singly lord to Talbot?\n",
      "\n",
      "HAMLET:\n",
      "Well, my lord, granding shall I think well as their tender o \n",
      "\n",
      "[14m 27s (1000 50%) train loss: 1.4236, test_loss: 1.5353]\n",
      "Who say her presently fair to you did deed\n",
      "Befear my kind of gentlemen, I muddige,\n",
      "be my hands?\n",
      "\n",
      "DAUPH \n",
      "\n",
      "[15m 9s (1050 52%) train loss: 1.4151, test_loss: 1.5077]\n",
      "Who the keep; that we were I go to makes me home and done\n",
      "more danger well.\n",
      "\n",
      "TIMON:\n",
      "And I have dint yo \n",
      "\n",
      "[15m 52s (1100 55%) train loss: 1.3968, test_loss: 1.5172]\n",
      "What cannot look'st with wars such a prison's\n",
      "hang me like a thing like a very night at mad?\n",
      "\n",
      "OTHELLO: \n",
      "\n",
      "[16m 35s (1150 57%) train loss: 1.4115, test_loss: 1.5110]\n",
      "Who like at me a man hell; shes in the mother:\n",
      "If you say I have and the plead.\n",
      "\n",
      "ANTIPHOLUS OF SYRACUS \n",
      "\n",
      "[17m 19s (1200 60%) train loss: 1.3564, test_loss: 1.5019]\n",
      "What conclude as thou hadst strong,\n",
      "And tatfore should make the good; therefore I oditing the place\n",
      "I  \n",
      "\n",
      "[18m 2s (1250 62%) train loss: 1.4141, test_loss: 1.5343]\n",
      "What find Caesar'd accused stains, become the great that serves\n",
      "and my father's faith, and all begun a \n",
      "\n",
      "[18m 45s (1300 65%) train loss: 1.3970, test_loss: 1.4959]\n",
      "What we is not in Christian!\n",
      "\n",
      "MISTRESS OF AUMERLE:\n",
      "Good my lord, my better too littlemen,\n",
      "By surmen's  \n",
      "\n",
      "[19m 29s (1350 67%) train loss: 1.3997, test_loss: 1.5177]\n",
      "Where is he counselling man. If I would not news\n",
      "Of our head and a moon in the crown, on the gods\n",
      "With \n",
      "\n",
      "[20m 12s (1400 70%) train loss: 1.4090, test_loss: 1.4994]\n",
      "Who show when nothing have hears and myself to attain\n",
      "and the speakers given our head, that you, sir,\n",
      " \n",
      "\n",
      "[20m 54s (1450 72%) train loss: 1.3759, test_loss: 1.5362]\n",
      "When he is stand, well.\n",
      "\n",
      "SICINIUS:\n",
      "I do fear for the spirits from his likes on thy service,\n",
      "And to be  \n",
      "\n",
      "[21m 38s (1500 75%) train loss: 1.3725, test_loss: 1.5049]\n",
      "What watch that come the fair way to the action\n",
      "To happose ancient make a courage\n",
      "Before this most ans \n",
      "\n",
      "[22m 20s (1550 77%) train loss: 1.3601, test_loss: 1.5156]\n",
      "Which seeks should seem follow'd to be all to draw in\n",
      "that the match'd of the best in my private.\n",
      "\n",
      "BOT \n",
      "\n",
      "[23m 2s (1600 80%) train loss: 1.3634, test_loss: 1.4994]\n",
      "Which whose mask it were breatness: I will see him practise.\n",
      "But we have dead, unpose for me.\n",
      "\n",
      "Clown:\n",
      " \n",
      "\n",
      "[23m 45s (1650 82%) train loss: 1.3509, test_loss: 1.5462]\n",
      "What a pardon me with our thalp,\n",
      "And he shall not needs have wives to so honey; and\n",
      "you may not threat \n",
      "\n",
      "[24m 28s (1700 85%) train loss: 1.3869, test_loss: 1.5105]\n",
      "What are gives it will not could warrant man;\n",
      "My bones honest hand, sir, to did I plague your own valo \n",
      "\n",
      "[25m 11s (1750 87%) train loss: 1.3838, test_loss: 1.5011]\n",
      "What we have you no more, if you\n",
      "through your life.\n",
      "\n",
      "EMILIA:\n",
      "Well, I pray him, and the purposes by the \n",
      "\n",
      "[25m 54s (1800 90%) train loss: 1.3506, test_loss: 1.4743]\n",
      "When he grace I see your\n",
      "good and offend it.\n",
      "\n",
      "POLIXENES:\n",
      "Is it would ask! you were.\n",
      "\n",
      "First Lord:\n",
      "But w \n",
      "\n",
      "[26m 36s (1850 92%) train loss: 1.3648, test_loss: 1.4971]\n",
      "Who is not smiles!\n",
      "\n",
      "SICINIUS:\n",
      "Yea, so your rain?\n",
      "\n",
      "FALSTAFF:\n",
      "My great loss in her, must not they are my \n",
      "\n",
      "[27m 19s (1900 95%) train loss: 1.3581, test_loss: 1.4802]\n",
      "What we thirty? not what I am that the spirits and\n",
      "let me found them we may hearts,\n",
      "But 'Helice hath w \n",
      "\n",
      "[28m 2s (1950 97%) train loss: 1.3483, test_loss: 1.4819]\n",
      "Why spirit the gove a shepherd.\n",
      "\n",
      "CLIFFORD:\n",
      "What men but I say the hand and my blood,\n",
      "But tell me alrea \n",
      "\n",
      "[28m 44s (2000 100%) train loss: 1.3203, test_loss: 1.4649]\n",
      "What's, madam; stand is like a courtier:\n",
      "This valiant charm and interpression, sir,\n",
      "The mark'd his bor \n",
      "\n"
     ]
    }
   ],
   "source": [
    "rnn = RNN(n_characters, hidden_size, n_characters, model_type=model_type, n_layers=n_layers).to(device)\n",
    "rnn_optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "start = time.time()\n",
    "all_losses = []\n",
    "test_losses = []\n",
    "loss_avg = 0\n",
    "test_loss_avg = 0\n",
    "print(rnn)\n",
    "\n",
    "# Each \"epoch\" here is a single optimization step on one random batch.\n",
    "print(\"Training for %d epochs...\" % n_epochs)\n",
    "for epoch in range(1, n_epochs + 1):\n",
    "    # float() accepts a Python float or a 0-d tensor. If train/eval_test return tensors,\n",
    "    # accumulating them directly would keep their autograd graphs alive until the next reset.\n",
    "    loss = float(train(rnn, *load_random_batch(train_text, chunk_len, batch_size), rnn_optimizer, criterion))\n",
    "    loss_avg += loss\n",
    "\n",
    "    test_loss = float(eval_test(rnn, *load_random_batch(test_text, chunk_len, batch_size)))\n",
    "    test_loss_avg += test_loss\n",
    "\n",
    "    if epoch % print_every == 0:\n",
    "        print('[%s (%d %d%%) train loss: %.4f, test_loss: %.4f]' % (time_since(start), epoch, epoch / n_epochs * 100, loss, test_loss))\n",
    "        print(generate(rnn, 'Wh', 100, device=device), '\\n')\n",
    "\n",
    "    # Record the mean losses over the last plot_every steps, then reset the accumulators.\n",
    "    if epoch % plot_every == 0:\n",
    "        all_losses.append(loss_avg / plot_every)\n",
    "        test_losses.append(test_loss_avg / plot_every)\n",
    "        loss_avg = 0\n",
    "        test_loss_avg = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K_pzABtVz13A"
   },
   "outputs": [],
   "source": [
    "# save network\n",
    "#torch.save(rnn.state_dict(), './rnn_generator.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 284
    },
    "executionInfo": {
     "elapsed": 559,
     "status": "ok",
     "timestamp": 1606204009019,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "lZoONrXPz13A",
    "outputId": "8909afea-ec15-4f19-edac-609438cbf3eb"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7fd388607828>]"
      ]
     },
     "execution_count": 29,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "\n",
    "# One point per plotting interval; float() accepts both Python floats and 0-d tensors.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot([float(v) for v in all_losses], label='train')\n",
    "ax.plot([float(v) for v in test_losses], color='r', label='test')\n",
    "ax.set(title='Average loss per plotting interval', xlabel='plotting interval', ylabel='loss')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1404,
     "status": "ok",
     "timestamp": 1606204009870,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "aNAbp2Aqz13A",
    "outputId": "fbe70816-edc4-4ede-83f7-03cffe154298"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The man wake him not hearing of his brother,\n",
      "Even it but the great sorrow and have them: and\n",
      "countrymen stands ill to seeming, the boy,\n",
      "I misa man cave his heads: therefore hath restamed\n",
      "thee my shadief in things and trust thou, for the faith\n",
      "She is very eater in here.\n",
      "\n",
      "ROSALINE:\n",
      "Give this made it in the lander: you\n",
      "see the day will necre 'gainst the heaps of heaven.\n",
      "\n",
      "MARGARET:\n",
      "I will take up the easy?\n",
      "\n",
      "CARDINAL WOLSEY:\n",
      "If your queen of the charess weep,\n",
      "And supportune shook all the wait with him, he's\n",
      "mine honest white high at all, which he small strength\n",
      "them to hear a prier get descends and the good that,\n",
      "With this place of my soul to secrets\n",
      "And shall seen mine reputation, I'll answer he\n",
      "made expled in our meaning strange plain.\n",
      "\n",
      "HOLOFERNES:\n",
      "Now, yet is not my husband on his admirally received\n",
      "And must be how Perchain, is ever masters to them as\n",
      "adversaries; which is to my sir, if you should\n",
      "here a villain the better to be directed and confess of\n",
      "the Frenchmen three suit; he is there\n"
     ]
    }
   ],
   "source": [
    "generated_text = evaluate(rnn, prime_str='Th', predict_len=1000)\n",
    "print(generated_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1cSYaORWz13A"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1474,
     "status": "ok",
     "timestamp": 1606238175208,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "qjIv9vg1z13A",
    "outputId": "942456cb-0ee3-4cff-c974-f225ab8a0eff"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "file_len = 4728857\n",
      "train len:  4255971\n",
      "test len:  472886\n"
     ]
    }
   ],
   "source": [
    "#######################################################\n",
    "# Extra credit: Charles Dickens corpus\n",
    "#######################################################\n",
    "all_characters = string.printable\n",
    "n_characters = len(all_characters)\n",
    "\n",
    "file_path = './charles_dickens.txt'\n",
    "# Read via a context manager so the file handle is closed; unidecode transliterates\n",
    "# non-ASCII characters to ASCII.\n",
    "with open(file_path) as f:\n",
    "    file = unidecode.unidecode(f.read())\n",
    "file_len = len(file)\n",
    "print('file_len =', file_len)\n",
    "\n",
    "# we will leave the last 1/10th of text as test\n",
    "split = int(0.9*file_len)\n",
    "train_text = file[:split]\n",
    "test_text = file[split:]\n",
    "\n",
    "print('train len: ', len(train_text))\n",
    "print('test len: ', len(test_text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 201,
     "status": "ok",
     "timestamp": 1606238184249,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "ubMm3pkoOiKu",
    "outputId": "496c2779-0a27-4b2a-f951-3cc089fb0d42"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t on how you may.\"\n",
      "\n",
      "A ghost-seeing effect in Joe's own countenance informed me that Herbert\n",
      "had entered the room. So, I presented Joe to Herbert, who held out his\n",
      "hand; but Joe backed from it, and held\n"
     ]
    }
   ],
   "source": [
    "chunk_len = 200\n",
    "\n",
    "def random_chunk(text):\n",
    "    \"\"\"Return a random slice of `text` that is chunk_len + 1 characters long.\"\"\"\n",
    "    # random.randint includes its upper bound, so the last start index that still\n",
    "    # leaves room for chunk_len + 1 characters is len(text) - chunk_len - 1.\n",
    "    start_index = random.randint(0, len(text) - chunk_len - 1)\n",
    "    end_index = start_index + chunk_len + 1\n",
    "    return text[start_index:end_index]\n",
    "\n",
    "print(random_chunk(train_text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 535045,
     "status": "ok",
     "timestamp": 1606238728365,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "ubJW_mQ7z13A",
    "outputId": "0b2e4e34-f548-46c4-d291-b53ed2be1223"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RNN(\n",
      "  (embedding_object): Embedding(100, 200)\n",
      "  (dn_object): GRU(200, 200)\n",
      "  (linear): Linear(in_features=200, out_features=100, bias=True)\n",
      ")\n",
      "Training for 750 epochs...\n",
      "[0m 34s (50 6%) train loss: 1.8457, test_loss: 1.9206]\n",
      "Whe bet to made sable and muscleser, ajing the gre to moth a mill heave in on oy no kning uppoo no to  \n",
      "\n",
      "[1m 8s (100 13%) train loss: 1.6645, test_loss: 1.7037]\n",
      "Which as\n",
      "towed thom all still a\n",
      "Teman, I such an of so the at Strunce would had ge no one words, and t \n",
      "\n",
      "[1m 43s (150 20%) train loss: 1.5787, test_loss: 1.6088]\n",
      "What the word, I done enchersevering breat stood him\n",
      "lidger--and I lempation, and the house he\n",
      "greatso \n",
      "\n",
      "[2m 17s (200 26%) train loss: 1.5320, test_loss: 1.5811]\n",
      "Whe was a cry, as to my eyes,\n",
      "and coming along this mother. Mr. Peggotty accuple as he face to the sug \n",
      "\n",
      "[2m 52s (250 33%) train loss: 1.4975, test_loss: 1.5915]\n",
      "Whe had a first me, with at\n",
      "hoarty, I had we early and his gentleman I got again. On the sated as you  \n",
      "\n",
      "[3m 26s (300 40%) train loss: 1.4718, test_loss: 1.5563]\n",
      "Where, that he would house-short was a place. 'He's a long some in her, and the remembround the procee \n",
      "\n",
      "[4m 1s (350 46%) train loss: 1.4789, test_loss: 1.5444]\n",
      "Whe, in the large mind, in a fatchened and moment, and I over the struck for it, and dual made the kin \n",
      "\n",
      "[4m 36s (400 53%) train loss: 1.4437, test_loss: 1.5138]\n",
      "Wheer, in his returned in that to keep him gave her so intent that she exceslamy that dealing of the b \n",
      "\n",
      "[5m 11s (450 60%) train loss: 1.4172, test_loss: 1.5463]\n",
      "When, as she could have  returned your down and good at a face.  The house it went of took hirst, capa \n",
      "\n",
      "[5m 46s (500 66%) train loss: 1.4354, test_loss: 1.4769]\n",
      "What I my libed his way, as was the gentleman you arating.\"\n",
      "\n",
      "\"No, which was a gras is not a\n",
      "such youse \n",
      "\n",
      "[6m 21s (550 73%) train loss: 1.4031, test_loss: 1.5110]\n",
      "Where to be approved yeart, that emoting the whole to in the stable course\n",
      "of the sulting to myself, a \n",
      "\n",
      "[6m 56s (600 80%) train loss: 1.3896, test_loss: 1.4864]\n",
      "What's imput ins there, and the grown with his is in a moment of the house, as a complete came at the  \n",
      "\n",
      "[7m 32s (650 86%) train loss: 1.4307, test_loss: 1.5016]\n",
      "When I must have been to be all I was got me, shall very look of the\n",
      "dear than to be usumed him to the \n",
      "\n",
      "[8m 8s (700 93%) train loss: 1.4032, test_loss: 1.4632]\n",
      "What is so made by\n",
      "unduring, after a little bettingly.  That was still to devious the us. So me he had \n",
      "\n",
      "[8m 44s (750 100%) train loss: 1.4123, test_loss: 1.4549]\n",
      "Where, old writion as you knows. A good neeares of the please, I promissonal do you assumed them to be \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Build the model, optimizer and loss; the hyperparameters used here\n",
    "# (hidden_size, model_type, n_layers, learning_rate, ...) are defined in earlier cells.\n",
    "rnn = RNN(n_characters, hidden_size, n_characters, model_type=model_type, n_layers=n_layers).to(device)\n",
    "rnn_optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "start = time.time()\n",
    "# One mean value per `plot_every` iterations, consumed by the plotting cell.\n",
    "all_losses = []\n",
    "test_losses = []\n",
    "# Running sums since the last plot point; reset after each append below.\n",
    "loss_avg = 0\n",
    "test_loss_avg = 0\n",
    "print(rnn)\n",
    "\n",
    "# Each \"epoch\" here is a single call to `train` on one random batch,\n",
    "# not a full pass over the training text.\n",
    "print(\"Training for %d epochs...\" % n_epochs)\n",
    "for epoch in range(1, n_epochs + 1):\n",
    "    loss = train(rnn, *load_random_batch(train_text, chunk_len, batch_size), rnn_optimizer, criterion)\n",
    "    loss_avg += loss\n",
    "\n",
    "    # Score one random held-out batch every iteration so the test curve is\n",
    "    # averaged over the same windows as the training curve.\n",
    "    test_loss = eval_test(rnn, *load_random_batch(test_text, chunk_len, batch_size))\n",
    "    test_loss_avg += test_loss\n",
    "\n",
    "    if epoch % print_every == 0:\n",
    "        # Reports the current batch's losses (not a running average), then a\n",
    "        # sample generated from the prime 'Wh' (length argument 100).\n",
    "        print('[%s (%d %d%%) train loss: %.4f, test_loss: %.4f]' % (time_since(start), epoch, epoch / n_epochs * 100, loss, test_loss))\n",
    "        print(generate(rnn, 'Wh', 100, device=device), '\\n')\n",
    "\n",
    "    if epoch % plot_every == 0:\n",
    "        all_losses.append(loss_avg / plot_every)\n",
    "        test_losses.append(test_loss_avg / plot_every)\n",
    "        loss_avg = 0\n",
    "        test_loss_avg = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 284
    },
    "executionInfo": {
     "elapsed": 441,
     "status": "ok",
     "timestamp": 1606238789430,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "F6RUCJO-z13A",
    "outputId": "e0ba8a14-38f1-4a5d-a171-6bafcb6d0782"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7fece206dc50>]"
      ]
     },
     "execution_count": 17,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "\n",
    "# Each stored value is the mean loss over `plot_every` iterations, so point i\n",
    "# belongs at iteration (i + 1) * plot_every rather than at list index i.\n",
    "iterations = [(i + 1) * plot_every for i in range(len(all_losses))]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(iterations, all_losses, label='train')\n",
    "ax.plot(iterations, test_losses, color='r', label='test')\n",
    "ax.set(title='Loss during training', xlabel='training iteration', ylabel='mean loss')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1119,
     "status": "ok",
     "timestamp": 1606238793308,
     "user": {
      "displayName": "Punit Jha",
      "photoUrl": "",
      "userId": "07885534541681120711"
     },
     "user_tz": 360
    },
    "id": "gZRziBDwz13A",
    "outputId": "fbac23e5-a76a-47bc-9bd2-37ccc02cdd70"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is, and pulled to bear when he saw, and she was then or a point of Mr. Sumble, as aspection, were no mading her handler stood for the stoods pundsact.\n",
      "\n",
      "\"Day!\"\n",
      "\n",
      "\"I want to you the most in you to see?\" said Mr. Lorry. She had been offer\n",
      "with mentions of lawburing his face. He had bad and with me to her hands to be to set of thrify found impury for it what leading in her against suffed people.\"\n",
      "\n",
      "I saw, and paused it hurry, stood Defarge; \"am this, she was the supporwaid for good stand, to have you wish is a gently enlangeasure in the controod so he proceeding of the Clara, and the piers, where me to be twomouson was never\n",
      "appeared\n",
      "to have a beginner, as the likely to the indived into the table instate at her feel on the possible upon his astreesers (I want to the little touched to her fast came and\n",
      "graught,\" returned him of me,\" replied the word\n",
      "with his last were pursuit. At the ainher. And mewhere she had relieved the coversame, and the kitches and said Joe\n",
      "was find a warms and wants repa\n"
     ]
    }
   ],
   "source": [
    "# Draw a long sample (predict_len=1000) from the trained model, primed with 'Is'.\n",
    "print(evaluate(rnn, predict_len=1000, prime_str='Is'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "juvJ_3BhOtLV"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "anaconda-cloud": {},
  "colab": {
   "collapsed_sections": [],
   "name": "MP4_P2_generation.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}